From 84488155fd57c903317752ff4079a7ba2db40256 Mon Sep 17 00:00:00 2001 From: Ryan Adam <88693529+TheRangiCrew@users.noreply.github.com> Date: Fri, 1 Nov 2024 01:21:39 +1300 Subject: [PATCH] Handle RPC errors (#168) --- db_test.go | 12 ++++++++++++ pkg/connection/connection.go | 30 ++++++++++++++++++++++++++++++ pkg/connection/ws.go | 22 ++++++++++++++++++++++ 3 files changed, 64 insertions(+) diff --git a/db_test.go b/db_test.go index df64152..37f1b76 100644 --- a/db_test.go +++ b/db_test.go @@ -464,3 +464,15 @@ func (s *SurrealDBTestSuite) TestQueryRaw() { fmt.Println(created) fmt.Println(selected) } + +func (s *SurrealDBTestSuite) TestRPCError() { + s.Run("Test valid query", func() { + _, err := surrealdb.Query[[]testUser](s.db, "SELECT * FROM users", map[string]interface{}{}) + s.Require().NoError(err) + }) + + s.Run("Test invalid query", func() { + _, err := surrealdb.Query[[]testUser](s.db, "SELEC * FROM users", map[string]interface{}{}) + s.Require().Error(err) + }) +} diff --git a/pkg/connection/connection.go b/pkg/connection/connection.go index 29860bb..1a65928 100644 --- a/pkg/connection/connection.go +++ b/pkg/connection/connection.go @@ -42,6 +42,9 @@ type BaseConnection struct { responseChannels map[string]chan []byte responseChannelsLock sync.RWMutex + errorChannels map[string]chan error + errorChannelsLock sync.RWMutex + notificationChannels map[string]chan Notification notificationChannelsLock sync.RWMutex } @@ -60,6 +63,20 @@ func (bc *BaseConnection) createResponseChannel(id string) (chan []byte, error) return ch, nil } +func (bc *BaseConnection) createErrorChannel(id string) (chan error, error) { + bc.errorChannelsLock.Lock() + defer bc.errorChannelsLock.Unlock() + + if _, ok := bc.errorChannels[id]; ok { + return nil, fmt.Errorf("%w: %v", constants.ErrIDInUse, id) + } + + ch := make(chan error) + bc.errorChannels[id] = ch + + return ch, nil +} + func (bc *BaseConnection) createNotificationChannel(liveQueryID string) (chan Notification, error) { bc.notificationChannelsLock.Lock() defer bc.notificationChannelsLock.Unlock() @@ -80,6 +97,12 @@ func (bc *BaseConnection) removeResponseChannel(id string) { delete(bc.responseChannels, id) } +func (bc *BaseConnection) removeErrorChannel(id string) { + bc.errorChannelsLock.Lock() + defer bc.errorChannelsLock.Unlock() + delete(bc.errorChannels, id) +} + func (bc *BaseConnection) getResponseChannel(id string) (chan []byte, bool) { bc.responseChannelsLock.RLock() defer bc.responseChannelsLock.RUnlock() @@ -87,6 +110,13 @@ func (bc *BaseConnection) getResponseChannel(id string) (chan []byte, bool) { return ch, ok } +func (bc *BaseConnection) getErrorChannel(id string) (chan error, bool) { + bc.errorChannelsLock.RLock() + defer bc.errorChannelsLock.RUnlock() + ch, ok := bc.errorChannels[id] + return ch, ok +} + func (bc *BaseConnection) getLiveChannel(id string) (chan Notification, bool) { bc.notificationChannelsLock.RLock() defer bc.notificationChannelsLock.RUnlock() diff --git a/pkg/connection/ws.go b/pkg/connection/ws.go index d0f975f..df4b474 100644 --- a/pkg/connection/ws.go +++ b/pkg/connection/ws.go @@ -44,6 +44,7 @@ func NewWebSocketConnection(p NewConnectionParams) *WebSocketConnection { unmarshaler: p.Unmarshaler, responseChannels: make(map[string]chan []byte), + errorChannels: make(map[string]chan error), notificationChannels: make(map[string]chan Notification), }, @@ -159,7 +160,12 @@ func (ws *WebSocketConnection) Send(dest interface{}, method string, params ...i if err != nil { return err } + errorChan, err := ws.createErrorChannel(id) + if err != nil { + return err + } defer ws.removeResponseChannel(id) + defer ws.removeErrorChannel(id) if err := ws.write(request); err != nil { return err @@ -177,6 +183,11 @@ func (ws *WebSocketConnection) Send(dest interface{}, method string, params ...i return ws.unmarshaler.Unmarshal(resBytes, dest) } return nil + case resErr, open := <-errorChan: + if !open { + return errors.New("error channel closed") + } + return resErr } } @@ -234,6 +245,17 @@ func (ws *WebSocketConnection) handleResponse(res []byte) { if rpcRes.Error != nil { err := fmt.Errorf("rpc request err %w", rpcRes.Error) ws.logger.Error(err.Error()) + + errChan, ok := ws.getErrorChannel(fmt.Sprintf("%v", rpcRes.ID)) + if !ok { + err := fmt.Errorf("unavailable ErrorChannel %+v", rpcRes.ID) + ws.logger.Error(err.Error()) + return + } + + defer close(errChan) + errChan <- rpcRes.Error + return }