Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Databases tree view #245

Merged
merged 6 commits into from
Nov 19, 2024
Merged
269 changes: 169 additions & 100 deletions pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package client

import (
"fmt"
"net/url"
"strings"

_ "github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx"
Expand All @@ -26,6 +28,8 @@ type databaseQuerier interface {
TableStructure(tableName string) (string, []interface{}, error)
Constraints(tableName string) (string, []interface{}, error)
Indexes(tableName string) (string, []interface{}, error)
ShowDatabases() (string, []interface{}, error)
ShowTablesPerDB(database string) (string, []interface{}, error)
}

// Client is used to store the pool of db connection.
Expand All @@ -34,7 +38,10 @@ type Client struct {
databaseQuerier databaseQuerier
driver, schema string
paginationManager *pagination.Manager
activeDatabase string
limit uint
showDataCatalog bool
dbs map[string]*sqlx.DB
}

// New return an instance of the client.
Expand All @@ -49,10 +56,11 @@ func New(opts command.Options) (*Client, error) {
return nil, err
}

c := Client{
c := &Client{
db: db,
driver: opts.Driver,
limit: opts.Limit,
dbs: make(map[string]*sqlx.DB),
}

if opts.Schema == "" {
Expand All @@ -77,10 +85,28 @@ func New(opts command.Options) (*Client, error) {
return nil, fmt.Errorf("%s driver not supported", c.driver)
}

if opts.DBName == "" {
switch c.driver {
case drivers.PostgreSQL, drivers.Postgres, drivers.MySQL:
c.showDataCatalog = true
dbs, err := c.ShowDatabases()
if err != nil {
return nil, err
}

for _, d := range dbs {
db, err := getDB(c.driver, conn, d)
if err != nil {
continue
}

c.dbs[d] = db
}
}
}

switch c.driver {
case drivers.Postgres:
fallthrough
case drivers.PostgreSQL:
case drivers.PostgreSQL, drivers.Postgres:
if _, err = db.Exec(fmt.Sprintf("set search_path='%s'", c.schema)); err != nil {
return nil, err
}
Expand All @@ -93,15 +119,56 @@ func New(opts command.Options) (*Client, error) {

c.paginationManager = pm

return &c, nil
return c, nil
}

func (c *Client) SetActiveDatabase(database string) {
c.activeDatabase = database
}

func (c *Client) ActiveDatabase() string {
return c.activeDatabase
}

// DB Return the db attribute.
func (c *Client) DB() *sqlx.DB {
return c.db
}

// Driver returns the driver of the database.
func (c *Client) Driver() string {
return c.driver
}

func (c *Client) ShowDataCatalog() bool {
return c.showDataCatalog
}

// Query returns performs the query and returns the result set and the column names.
func (c *Client) Query(q string, args ...interface{}) ([][]string, []string, error) {
resultSet := [][]string{}
var (
resultSet = [][]string{}
db *sqlx.DB
ok bool
)

db = c.db

if c.activeDatabase != "" {
switch c.driver {
case drivers.Postgres, drivers.PostgreSQL, drivers.MySQL:
db, ok = c.dbs[c.activeDatabase]
if !ok {
return nil, nil, fmt.Errorf(
"connection with %s database not found",
c.activeDatabase,
)
}
}
}

// Runs the query extracting the content of the view calling the Buffer method.
rows, err := c.db.Queryx(q, args...)
rows, err := db.Queryx(q, args...)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -134,6 +201,9 @@ func (c *Client) Query(q string, args ...interface{}) ([][]string, []string, err

resultSet = append(resultSet, s)
}
if err := rows.Err(); err != nil {
return nil, nil, err
}

return resultSet, columnNames, nil
}
Expand All @@ -160,20 +230,6 @@ type Metadata struct {

// Metadata returns the most relevant data from a given table.
func (c *Client) Metadata(tableName string) (*Metadata, error) {
count, err := c.tableCount(tableName)
if err != nil {
return nil, err
}

pm, err := pagination.New(c.limit, count, tableName)
if err != nil {
return nil, err
}

c.paginationManager = pm

pages := c.paginationManager.TotalPages()

tcRows, tcColumns, err := c.tableContent(tableName)
if err != nil {
return nil, err
Expand Down Expand Up @@ -211,36 +267,38 @@ func (c *Client) Metadata(tableName string) (*Metadata, error) {
Rows: iRows,
Columns: iColumns,
},
TotalPages: pages,
}

return &m, nil
}

func (c *Client) TotalPages() int {
if c.paginationManager != nil {
return c.paginationManager.TotalPages()
}

return 0
}

// ShowTables list all the tables in the database on the tables panel.
func (c *Client) ShowTables() ([]string, error) {
func (c *Client) ShowTablesPerDB(database string) ([]string, error) {
var (
query string
err error
args []interface{}
db *sqlx.DB
ok bool
)

tables := make([]string, 0)

query, args, err = c.databaseQuerier.ShowTables()
query, args, err = c.databaseQuerier.ShowTablesPerDB(database)
if err != nil {
return nil, err
}

rows, err := c.db.Queryx(query, args...)
switch c.driver {
case drivers.PostgreSQL, drivers.Postgres, drivers.MySQL:
db, ok = c.dbs[database]
if !ok {
return nil, fmt.Errorf("connection with %s database not found", database)
}
default:
db = c.db
}

rows, err := db.Queryx(query, args...)
if err != nil {
return nil, err
}
Expand All @@ -253,73 +311,81 @@ func (c *Client) ShowTables() ([]string, error) {

tables = append(tables, table)
}
if err := rows.Err(); err != nil {
return nil, err
}

return tables, nil
}

// NextPage returns the next page of the given table, based off the limit and the offsite.
func (c *Client) NextPage() (*Table, int, error) {
if err := c.paginationManager.NextPage(); err != nil {
return nil, 0, err
}
// ShowTables list all the tables in the database on the tables panel.
func (c *Client) ShowTables() ([]string, error) {
var (
query string
err error
args []interface{}
)

tables := make([]string, 0)

r, col, err := c.tableContent(c.paginationManager.CurrentTable())
query, args, err = c.databaseQuerier.ShowTables()
if err != nil {
return nil, 0, err
return nil, err
}

t := Table{
name: c.paginationManager.CurrentTable(),
Rows: r,
Columns: col,
rows, err := c.db.Queryx(query, args...)
if err != nil {
return nil, err
}

page := c.paginationManager.CurrentPage()

return &t, page, nil
}
for rows.Next() {
var table string
if err := rows.Scan(&table); err != nil {
return nil, err
}

// PreviousPage returns the next page of the given table, based off the limit and the offsite.
func (c *Client) PreviousPage() (*Table, int, error) {
if err := c.paginationManager.PreviousPage(); err != nil {
return nil, 0, err
tables = append(tables, table)
}

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Each row.Nect loop must be followed by checking rows.Err()

Suggested change
}
}
if err:= rows.Err(); err != nil {
return nil, err
}


r, col, err := c.tableContent(c.paginationManager.CurrentTable())
if err != nil {
return nil, 0, err
if err := rows.Err(); err != nil {
return nil, err
}

t := Table{
name: c.paginationManager.CurrentTable(),
Rows: r,
Columns: col,
}
return tables, nil
}

page := c.paginationManager.CurrentPage()
// ShowDatabases returns a list of the databases the user has access to.
func (c *Client) ShowDatabases() ([]string, error) {
var (
query string
err error
args []interface{}
)

return &t, page, nil
}
databases := make([]string, 0)

// ResetPagination resets the paginationManager field.
func (c *Client) ResetPagination() error {
pm, err := pagination.New(c.limit, 0, "")
query, args, err = c.databaseQuerier.ShowDatabases()
if err != nil {
return err
return nil, err
}

c.paginationManager = pm
return nil
}
rows, err := c.db.Queryx(query, args...)
if err != nil {
return nil, err
}

// DB Return the db attribute.
func (c *Client) DB() *sqlx.DB {
return c.db
}
for rows.Next() {
var database string
if err := rows.Scan(&database); err != nil {
return nil, err
}

// Driver returns the driver of the database.
func (c *Client) Driver() string {
return c.driver
databases = append(databases, database)
}

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same issue about missing rows.Err check

if err := rows.Err(); err != nil {
return nil, err
}

return databases, nil
}

// TableContent returns all the rows of a table.
Expand Down Expand Up @@ -360,27 +426,6 @@ func (c *Client) tableContent(tableName string) ([][]string, []string, error) {
return c.Query(query)
}

// tableCount returns the count of a given table.
func (c *Client) tableCount(tableName string) (int, error) {
var (
query string
count int
)

switch c.driver {
case drivers.Postgres, drivers.PostgreSQL:
query = fmt.Sprintf("SELECT COUNT(*) FROM %q;", tableName)
default:
query = fmt.Sprintf("SELECT COUNT(*) FROM %s;", tableName)
}

if err := c.db.Get(&count, query); err != nil {
return 0, err
}

return count, nil
}

// tableStructure returns the structure of the table columns.
func (c *Client) tableStructure(tableName string) ([][]string, []string, error) {
var (
Expand Down Expand Up @@ -416,3 +461,27 @@ func (c *Client) indexes(tableName string) ([][]string, []string, error) {

return c.Query(query, args...)
}

func getDB(driver, connString, database string) (*sqlx.DB, error) {
var newConnString string

switch driver {
case drivers.MySQL:
newConnString = strings.Replace(connString, "/", fmt.Sprintf("/%s", database), 1)
default:
u, err := url.Parse(connString)
if err != nil {
return nil, err
}

u.Path = "/" + database
newConnString = u.String()
Comment on lines +471 to +478

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be safer (I might have forgotten some)

Suggested change
default:
u, err := url.Parse(connString)
if err != nil {
return nil, err
}
u.Path = "/" + database
newConnString = u.String()
case drivers.PostgreSQL, drivers.Postgres:
u, err := url.Parse(connString)
if err != nil {
return nil, err
}
u.Path = "/" + database
newConnString = u.String()
default:
return nil, fmt.Errorf("unsupported driver %s", driver)

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I disagree with this, since the driver validation is done somewhere else and the MySQL URL parsing is the corner case here The default behavior is parsing the DSN with url.Parse.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah OK then. I misunderstood

}

db, err := sqlx.Open(driver, newConnString)
if err != nil {
return nil, err
}

return db, nil
}
Loading