Skip to content

Commit

Permalink
feat: client refactor (#183)
Browse files Browse the repository at this point in the history
* feat(client): add an interface for database queries handler

* feat(client): add a postgres database querier that implements database querier interface

* feat(client): add mysql struct that implements database querier

* feat(client): add a sqlite struct that implements the database querier interface

* feat(client): add the database querier and its implementations to the client

* test(client): add some validations to the test
  • Loading branch information
danvergara authored Aug 15, 2023
1 parent 1b43eac commit b1eefd1
Show file tree
Hide file tree
Showing 5 changed files with 298 additions and 137 deletions.
182 changes: 50 additions & 132 deletions pkg/client/client.go
Original file line number Diff line number Diff line change
@@ -1,27 +1,35 @@
package client

import (
"errors"
"fmt"

sq "github.com/Masterminds/squirrel"
_ "github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx"
_ "github.com/lib/pq"
_ "modernc.org/sqlite"

"github.com/danvergara/dblab/pkg/command"
"github.com/danvergara/dblab/pkg/connection"
"github.com/danvergara/dblab/pkg/drivers"
"github.com/danvergara/dblab/pkg/pagination"
"github.com/jmoiron/sqlx"

// mysql driver.
_ "github.com/go-sql-driver/mysql"
// postgres driver.
_ "github.com/lib/pq"
// sqlite driver.
_ "modernc.org/sqlite"
)

// databaseQuerier is an interface that indicates the methods
// a given type has to implement to interact with a database,
// to get specific data.
// This allows us to decouple the client from the database implementation and
// make adding new databases easier.
type databaseQuerier interface {
ShowTables() (string, []interface{}, error)
TableStructure(tableName string) (string, []interface{}, error)
Constraints(tableName string) (string, []interface{}, error)
Indexes(tableName string) (string, []interface{}, error)
}

// Client is used to store the pool of db connection.
type Client struct {
db *sqlx.DB
databaseQuerier databaseQuerier
driver, schema string
paginationManager *pagination.Manager
limit uint
Expand Down Expand Up @@ -51,6 +59,18 @@ func New(opts command.Options) (*Client, error) {
c.schema = opts.Schema
}

// This is where an implementation of databaseQuerier is getting picked up.
switch c.driver {
case drivers.Postgres, drivers.PostgreSQL:
c.databaseQuerier = newPostgres(c.schema)
case drivers.MySQL:
c.databaseQuerier = newMySQL()
case drivers.SQLite:
c.databaseQuerier = newSQLite()
default:
return nil, fmt.Errorf("%s driver not supported", c.driver)
}

switch c.driver {
case drivers.Postgres:
fallthrough
Expand Down Expand Up @@ -209,31 +229,9 @@ func (c *Client) ShowTables() ([]string, error) {

tables := make([]string, 0)

switch c.driver {
case drivers.Postgres:
fallthrough
case drivers.PostgreSQL:
psql := sq.StatementBuilder.PlaceholderFormat(sq.Dollar)
query, args, err = psql.Select("table_name").
From("information_schema.tables").
Where(sq.Eq{"table_schema": c.schema}).
OrderBy("table_name").
ToSql()
if err != nil {
return nil, err
}

case drivers.MySQL:
query = "SHOW TABLES;"
case drivers.SQLite:
query = `
SELECT
name
FROM
sqlite_schema
WHERE
type ='table' AND
name NOT LIKE 'sqlite_%';`
query, args, err = c.databaseQuerier.ShowTables()
if err != nil {
return nil, err
}

rows, err := c.db.Queryx(query, args...)
Expand Down Expand Up @@ -353,116 +351,36 @@ func (c *Client) tableCount(tableName string) (int, error) {

// tableStructure returns the structure of the table columns.
func (c *Client) tableStructure(tableName string) ([][]string, []string, error) {
var query string

switch c.driver {
case drivers.Postgres:
fallthrough
case drivers.PostgreSQL:
psql := sq.StatementBuilder.PlaceholderFormat(sq.Dollar)

query, args, err := psql.Select(
"c.column_name",
"c.is_nullable",
"c.data_type",
"c.character_maximum_length",
"c.numeric_precision",
"c.numeric_scale",
"c.ordinal_position",
"tc.constraint_type AS pkey",
).
From("information_schema.columns AS c").
LeftJoin(
`information_schema.constraint_column_usage AS ccu
ON c.table_schema = ccu.table_schema
AND c.table_name = ccu.table_name
AND c.column_name = ccu.column_name`,
).
LeftJoin(
`information_schema.table_constraints AS tc
ON ccu.constraint_schema = tc.constraint_schema
AND ccu.constraint_name = tc.constraint_name`,
).
Where(
sq.And{
sq.Eq{"c.table_schema": c.schema},
sq.Eq{"c.table_name": tableName},
},
).
ToSql()
if err != nil {
return nil, nil, err
}
var (
query string
err error
args []interface{}
)

return c.Query(query, args...)
case drivers.MySQL:
query = fmt.Sprintf("DESCRIBE %s;", tableName)
return c.Query(query)
case drivers.SQLite:
query = fmt.Sprintf("PRAGMA table_info(%s);", tableName)
return c.Query(query)
default:
return nil, nil, errors.New("not supported driver")
query, args, err = c.databaseQuerier.TableStructure(tableName)
if err != nil {
return nil, nil, err
}

return c.Query(query, args...)
}

// constraints returns the resultet of from information_schema.table_constraints.
func (c *Client) constraints(tableName string) ([][]string, []string, error) {
var (
query sq.SelectBuilder
sql string
)

query = sq.Select(
`tc.constraint_name`,
`tc.table_name`,
`tc.constraint_type`,
).
From("information_schema.table_constraints AS tc").
Where("tc.table_name = ?")

switch c.driver {
case drivers.SQLite:
sql = `
SELECT *
FROM
sqlite_master
WHERE
type='table' AND name = ?;`
return c.Query(sql, tableName)
case drivers.Postgres:
fallthrough
case drivers.PostgreSQL:
query = query.Where(fmt.Sprintf("tc.table_schema = '%s'", c.schema))
query = query.PlaceholderFormat(sq.Dollar)
}

sql, _, err := query.ToSql()
sql, args, err := c.databaseQuerier.Constraints(tableName)
if err != nil {
return nil, nil, err
}

return c.Query(sql, tableName)
return c.Query(sql, args...)
}

// indexes returns a resulset with the information of the indexes given a table name.
func (c *Client) indexes(tableName string) ([][]string, []string, error) {
var query string

switch c.driver {
case drivers.Postgres:
fallthrough
case drivers.PostgreSQL:
query = "SELECT * FROM pg_indexes WHERE tablename = $1;"
return c.Query(query, tableName)
case drivers.MySQL:
query = fmt.Sprintf("SHOW INDEX FROM %s", tableName)
return c.Query(query)
case drivers.SQLite:
query = `PRAGMA index_list(%s);`
query = fmt.Sprintf(query, tableName)
return c.Query(query)
default:
return nil, nil, errors.New("not supported driver")
query, args, err := c.databaseQuerier.Indexes(tableName)
if err != nil {
return nil, nil, err
}

return c.Query(query, args...)
}
7 changes: 2 additions & 5 deletions pkg/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,10 @@ import (
"os"
"testing"

// mysql driver.
_ "github.com/go-sql-driver/mysql"
// postgres driver.
_ "github.com/lib/pq"
// sqlite driver.
_ "modernc.org/sqlite"

"github.com/stretchr/testify/assert"
_ "modernc.org/sqlite"

"github.com/danvergara/dblab/pkg/command"
"github.com/danvergara/dblab/pkg/drivers"
Expand Down Expand Up @@ -334,6 +330,7 @@ func TestMetadata(t *testing.T) {
m, err := c.Metadata("products")

assert.NoError(t, err)
assert.NotNil(t, m)

// Total count.
assert.Equal(t, m.TotalPages, 1)
Expand Down
56 changes: 56 additions & 0 deletions pkg/client/mysql.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package client

import (
"fmt"

sq "github.com/Masterminds/squirrel"
)

// mysql struct is in charge of perform all the mysql related queries,
// without the client knowing.
type mysql struct{}

// a validation to see if mysql is implementing databaseQuerier.
var _ databaseQuerier = (*mysql)(nil)

// returns a pointer to a mysql.
func newMySQL() *mysql {
m := mysql{}
return &m
}

// ShowTables returns a query to retrieve all the tables.
func (m *mysql) ShowTables() (string, []interface{}, error) {
query := "SHOW TABLES;"
return query, nil, nil
}

// TableStructure returns a query string to retrieve all the relevant information of a given table.
func (m *mysql) TableStructure(tableName string) (string, []interface{}, error) {
query := fmt.Sprintf("DESCRIBE %s;", tableName)
return query, nil, nil
}

// Constraints returns all the constraints of a given table.
func (m *mysql) Constraints(tableName string) (string, []interface{}, error) {
query := sq.Select(
`tc.constraint_name`,
`tc.table_name`,
`tc.constraint_type`,
).
From("information_schema.table_constraints AS tc").
Where("tc.table_name = ?", tableName)

sql, args, err := query.ToSql()
if err != nil {
return "", nil, err
}

return sql, args, err
}

// Indexes returns a query to get all the indexes of a table.
func (m *mysql) Indexes(tableName string) (string, []interface{}, error) {
query := fmt.Sprintf("SHOW INDEX FROM %s", tableName)
return query, nil, nil
}
Loading

0 comments on commit b1eefd1

Please sign in to comment.