Skip to content

Commit

Permalink
Refactor SmartUnmarshal (#90)
Browse files Browse the repository at this point in the history
Refactoring of SmartUnmarshal and code cleanup
  • Loading branch information
ElecTwix authored Aug 11, 2023
1 parent 731c716 commit 135edbe
Show file tree
Hide file tree
Showing 6 changed files with 243 additions and 246 deletions.
182 changes: 2 additions & 180 deletions db.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,12 @@
package surrealdb

import (
"encoding/json"
"errors"
"fmt"

"reflect"

"github.com/surrealdb/surrealdb.go/pkg/constants"
"github.com/surrealdb/surrealdb.go/pkg/websocket"
)

const statusOK = "OK"

var (
InvalidResponse = errors.New("invalid SurrealDB response") //nolint:stylecheck
ErrQuery = errors.New("error occurred processing the SurrealDB query")
ErrNoRow = errors.New("error no row")
)

// DB is a client for the SurrealDB database that holds are websocket connection.
type DB struct {
ws websocket.WebSocket
Expand All @@ -28,159 +17,6 @@ func New(url string, ws websocket.WebSocket) (*DB, error) {
return &DB{ws}, nil
}

// Unmarshal loads a SurrealDB response into a struct.
func Unmarshal(data, v interface{}) error {
var jsonBytes []byte
var err error
if isSlice(v) {
assertedData, ok := data.([]interface{})
if !ok {
return fmt.Errorf("failed to deserialise response to slice: %w", InvalidResponse)
}
jsonBytes, err = json.Marshal(assertedData)
if err != nil {
return fmt.Errorf("failed to deserialise response '%+v' to slice: %w", assertedData, InvalidResponse)
}
} else {
jsonBytes, err = json.Marshal(data)
if err != nil {
return fmt.Errorf("failed to deserialise response '%+v' to object: %w", data, err)
}
}
if err != nil {
return err
}

err = json.Unmarshal(jsonBytes, v)
if err != nil {
return fmt.Errorf("failed unmarshaling jsonBytes '%+v': %w", jsonBytes, err)
}
return nil
}

// UnmarshalRaw loads a raw SurrealQL response returned by Query into a struct. Queries that return with results will
// return ok = true, and queries that return with no results will return ok = false.
func UnmarshalRaw(rawData, v interface{}) (ok bool, err error) {
var data []interface{}
if data, ok = rawData.([]interface{}); !ok {
return false, fmt.Errorf("failed raw unmarshaling to interface slice: %w", InvalidResponse)
}

var responseObj map[string]interface{}
if responseObj, ok = data[0].(map[string]interface{}); !ok {
return false, fmt.Errorf("failed mapping to response object: %w", InvalidResponse)
}

var status string
if status, ok = responseObj["status"].(string); !ok {
return false, fmt.Errorf("failed retrieving status: %w", InvalidResponse)
}
if status != statusOK {
return false, fmt.Errorf("status was not ok: %w", ErrQuery)
}

result := responseObj["result"]
if len(result.([]interface{})) == 0 {
return false, nil
}
err = Unmarshal(result, v)
if err != nil {
return false, fmt.Errorf("failed to unmarshal: %w", err)
}

return true, nil
}

// Used for RawQuery Unmarshaling
type RawQuery[I any] struct {
Status string `json:"status"`
Time string `json:"time"`
Result I `json:"result"`
Detail string `json:"detail"`
}

// SmartUnmarshal using generics for return desired type.
// Supports both raw and normal queries.
func SmartUnmarshal[I any](respond interface{}, wrapperError error) (data I, err error) {
if wrapperError != nil {
return data, wrapperError
}
var bytes []byte
if arrResp, isArr := respond.([]interface{}); len(arrResp) > 0 {
if dataMap, ok := arrResp[0].(map[string]interface{}); ok && isArr {
if _, ok := dataMap["status"]; ok {
if bytes, err = json.Marshal(respond); err == nil {
var raw []RawQuery[I]
if err = json.Unmarshal(bytes, &raw); err == nil {
if raw[0].Status != statusOK {
err = fmt.Errorf("%s: %s", raw[0].Status, raw[0].Detail)
}
data = raw[0].Result
}
}
return data, err
}
}
}
if bytes, err = json.Marshal(respond); err == nil {
err = json.Unmarshal(bytes, &data)
}
return data, err
}

// Used for define table name, it has no value.
type Basemodel struct{}

// Smart Marshal Errors
var (
ErrNotStruct = errors.New("data is not struct")
ErrNotValidFunc = errors.New("invalid function")
)

// SmartUnmarshal can be used with all DB methods with generics and type safety.
// This handles errors and can use any struct tag with `BaseModel` type.
// Warning: "ID" field is case sensitive and expect string.
// Upon failure, the following will happen
// 1. If there are some ID on struct it will fill the table with the ID
// 2. If there are struct tags of the type `Basemodel`, it will use those values instead
// 3. If everything above fails or the IDs do not exist, SmartUnmarshal will use the struct name as the table name.
func SmartMarshal[I any](inputfunc interface{}, data I) (output interface{}, err error) {
var table string
datatype := reflect.TypeOf(data)
datavalue := reflect.ValueOf(data)
if datatype.Kind() == reflect.Pointer {
datatype = datatype.Elem()
datavalue = datavalue.Elem()
}
if datatype.Kind() == reflect.Struct {
if _, ok := datavalue.Field(0).Interface().(Basemodel); ok {
if temptable, ok := datatype.Field(0).Tag.Lookup("table"); ok {
table = temptable
} else {
table = reflect.TypeOf(data).Name()
}
}
if id, ok := datatype.FieldByName("ID"); ok {
if id.Type.Kind() == reflect.String {
if str, ok := datavalue.FieldByName("ID").Interface().(string); ok {
if str != "" {
table = str
}
}
}
}
} else {
return nil, ErrNotStruct
}
if function, ok := inputfunc.(func(thing string, data interface{}) (interface{}, error)); ok {
return function(table, data)
}
if function, ok := inputfunc.(func(thing string) (interface{}, error)); ok {
return function(table)
}
return nil, ErrNotValidFunc
}

// --------------------------------------------------
// Public methods
// --------------------------------------------------
Expand Down Expand Up @@ -308,21 +144,7 @@ func (db *DB) send(method string, params ...interface{}) (interface{}, error) {
// resp is a helper method for parsing the response from a query.
func (db *DB) resp(_ string, _ []interface{}, res interface{}) (interface{}, error) {
if res == nil {
return nil, ErrNoRow
return nil, constants.ErrNoRow
}
return res, nil
}

func isSlice(possibleSlice interface{}) bool {
slice := false

switch v := possibleSlice.(type) { //nolint:gocritic
default:
res := fmt.Sprintf("%s", v)
if res == "[]" || res == "&[]" || res == "*[]" {
slice = true
}
}

return slice
}
Loading

0 comments on commit 135edbe

Please sign in to comment.