From b557479e8957bcd05b5b131fe794000bb862e5b3 Mon Sep 17 00:00:00 2001 From: Rob Gonnella Date: Fri, 23 Jun 2023 01:42:48 -0400 Subject: [PATCH] Fixes race conditions --- Makefile | 8 +++- internal/core/core.go | 10 +++++ internal/core/core_test.go | 39 +++++++++++++---- internal/core/monitor.go | 28 ++++++++---- internal/server/service.go | 18 +++++--- internal/server/service_test.go | 2 +- internal/ui/component/context.go | 16 ++++++- internal/ui/component/event.go | 7 --- internal/ui/component/server.go | 8 ---- internal/ui/launch.go | 8 +++- internal/ui/view.go | 75 ++++++++++++++++---------------- 11 files changed, 139 insertions(+), 80 deletions(-) diff --git a/Makefile b/Makefile index d040a9f..c75d83a 100644 --- a/Makefile +++ b/Makefile @@ -8,12 +8,18 @@ all: $(PREFIX)/ops $(PREFIX)/ops: $(go_deps) cd cli && go build -ldflags '-s -w' -o $(@) +$(PREFIX)/ops-dev: $(go_deps) + cd cli && go build -race -ldflags '-s -w' -o $(@) + .PHONY: ops ops: $(PREFIX)/ops +.PHONY: dev +dev: $(PREFIX)/ops-dev + .PHONY: test test: - go test -v ./... + go test -v -race ./... .PHONY: mock mock: diff --git a/internal/core/core.go b/internal/core/core.go index 5145239..a20cd64 100644 --- a/internal/core/core.go +++ b/internal/core/core.go @@ -30,6 +30,7 @@ type Core struct { discovery discovery.Service serverService server.Service log logger.Logger + eventSubscription int evtListeners []*EventListener serverPollListeners []*ServerPollListener nextListenerId int @@ -64,6 +65,9 @@ func New( func (c *Core) Stop() error { c.discovery.Stop() + if c.eventSubscription != 0 { + c.serverService.StopStream(c.eventSubscription) + } c.cancel() return c.ctx.Err() } @@ -146,6 +150,9 @@ func (c *Core) RemoveEventListener(id int) { listeners := []*EventListener{} for _, listener := range c.evtListeners { + if listener.id == id { + close(listener.channel) + } if listener.id != id { listeners = append(listeners, listener) } @@ -175,6 +182,9 @@ func (c *Core) RemoveServerPollListener(id int) { listeners := []*ServerPollListener{} for _, listener := range c.serverPollListeners { + if listener.id == id { + close(listener.channel) + } if listener.id != id { listeners = append(listeners, listener) } diff --git a/internal/core/core_test.go b/internal/core/core_test.go index 79ef390..a43e85c 100644 --- a/internal/core/core_test.go +++ b/internal/core/core_test.go @@ -1,8 +1,8 @@ package core_test import ( + "sync" "testing" - "time" "github.com/golang/mock/gomock" "github.com/robgonnella/ops/internal/config" @@ -159,7 +159,7 @@ func TestCore(t *testing.T) { }) t.Run("registers and removes event listener", func(st *testing.T) { - evtChan := make(chan *event.Event, 1) + evtChan := make(chan *event.Event) id := coreService.RegisterEventListener(evtChan) assert.Equal(st, 1, id) @@ -168,7 +168,7 @@ func TestCore(t *testing.T) { }) t.Run("registers and removes server listener", func(st *testing.T) { - serverChan := make(chan []*server.Server, 1) + serverChan := make(chan []*server.Server) id := coreService.RegisterServerPollListener(serverChan) @@ -178,16 +178,39 @@ func TestCore(t *testing.T) { }) t.Run("monitors network", func(st *testing.T) { - defer coreService.Stop() + wg := sync.WaitGroup{} + wg.Add(2) - mockServerService.EXPECT().StreamEvents(gomock.Any()) + mockServerService.EXPECT().StreamEvents(gomock.Any()).Return(1) mockServerService.EXPECT().GetAllServersInNetworkTargets(conf.Targets) - mockScanner.EXPECT().Scan() + mockScanner.EXPECT().Scan().DoAndReturn(func() ([]*discovery.DiscoveryResult, error) { + defer func() { + coreService.Stop() + wg.Done() + }() + return []*discovery.DiscoveryResult{ + { + ID: "id", + Hostname: "hostname", + IP: "ip", + OS: "os", + Status: server.StatusOnline, + Ports: []discovery.Port{ + { + ID: 22, + Status: discovery.PortOpen, + }, + }, + }, + }, nil + }) mockScanner.EXPECT().Stop() - mockServerService.EXPECT().StopStream(gomock.Any()).AnyTimes() + mockServerService.EXPECT().StopStream(1).Do(func(int) { + wg.Done() + }) go coreService.Monitor() - time.Sleep(time.Millisecond * 10) + wg.Wait() }) } diff --git a/internal/core/monitor.go b/internal/core/monitor.go index ca9933a..0a7f484 100644 --- a/internal/core/monitor.go +++ b/internal/core/monitor.go @@ -10,12 +10,10 @@ import ( // Run runs the sequence driver for the HostInstallStage func (c *Core) Monitor() error { - evtReceiveChan := make(chan *event.Event, 100) + evtReceiveChan := make(chan *event.Event) // create event subscription - subscription := c.serverService.StreamEvents(evtReceiveChan) - - defer c.serverService.StopStream(subscription) + c.eventSubscription = c.serverService.StreamEvents(evtReceiveChan) // Start network scanner go c.discovery.MonitorNetwork() @@ -24,12 +22,16 @@ func (c *Core) Monitor() error { go c.pollForDatabaseUpdates() for { - select { - case <-c.ctx.Done(): - return c.ctx.Err() - case evt := <-evtReceiveChan: - c.handleServerEvent(evt) + evt, ok := <-evtReceiveChan + if !ok { + c.mux.Lock() + for _, listener := range c.evtListeners { + close(listener.channel) + } + c.mux.Unlock() + return nil } + c.handleServerEvent(evt) } } @@ -47,6 +49,9 @@ func (c *Core) handleServerEvent(evt *event.Event) { c.log.Info().Fields(fields).Msg("Event Received") + c.mux.Lock() + defer c.mux.Unlock() + for _, listener := range c.evtListeners { listener.channel <- evt } @@ -59,6 +64,9 @@ func (c *Core) pollForDatabaseUpdates() error { for { select { case <-c.ctx.Done(): + for _, listener := range c.serverPollListeners { + close(listener.channel) + } return c.ctx.Err() default: if errCount >= 5 { @@ -77,9 +85,11 @@ func (c *Core) pollForDatabaseUpdates() error { errCount = 0 + c.mux.Lock() for _, listener := range c.serverPollListeners { listener.channel <- response } + c.mux.Unlock() time.Sleep(pollTime) } diff --git a/internal/server/service.go b/internal/server/service.go index 4742e05..6b9bef2 100644 --- a/internal/server/service.go +++ b/internal/server/service.go @@ -1,9 +1,9 @@ package server import ( - "context" "errors" "net" + "sync" "github.com/imdario/mergo" "github.com/robgonnella/ops/internal/config" @@ -37,23 +37,21 @@ func filterChannels(channels []*eventChannel, fn func(c *eventChannel) bool) []* // ServerService represents our server service implementation type ServerService struct { - ctx context.Context log logger.Logger repo Repo evtChans []*eventChannel + mux sync.Mutex } // NewService returns a new instance server service func NewService(conf config.Config, repo Repo) *ServerService { log := logger.New() - ctx := context.Background() - return &ServerService{ - ctx: ctx, log: log, repo: repo, evtChans: []*eventChannel{}, + mux: sync.Mutex{}, } } @@ -175,14 +173,22 @@ func (s *ServerService) StreamEvents(send chan *event.Event) int { send: send, } + s.mux.Lock() s.evtChans = append(s.evtChans, evtChan) + s.mux.Unlock() return evtChan.id } func (s *ServerService) StopStream(id int) { + s.mux.Lock() + defer s.mux.Unlock() + s.log.Info().Int("channelID", id).Msg("Filtering channel") s.evtChans = filterChannels(s.evtChans, func(c *eventChannel) bool { + if c.id == id { + close(c.send) + } return c.id != id }) } @@ -193,6 +199,8 @@ func (s *ServerService) GetServer(id string) (*Server, error) { } func (s *ServerService) sendServerUpdateEvent(server *Server) { + s.mux.Lock() + defer s.mux.Unlock() for _, clientChan := range s.evtChans { clientChan.send <- &event.Event{ Type: event.SeverUpdate, diff --git a/internal/server/service_test.go b/internal/server/service_test.go index 4622ce1..99c741d 100644 --- a/internal/server/service_test.go +++ b/internal/server/service_test.go @@ -127,7 +127,7 @@ func TestServerService(t *testing.T) { }) t.Run("streams events", func(st *testing.T) { - evtChan := make(chan *event.Event, 1) + evtChan := make(chan *event.Event) streamID := service.StreamEvents(evtChan) diff --git a/internal/ui/component/context.go b/internal/ui/component/context.go index 4437c21..12944f8 100644 --- a/internal/ui/component/context.go +++ b/internal/ui/component/context.go @@ -7,6 +7,7 @@ import ( "github.com/gdamore/tcell/v2" "github.com/rivo/tview" "github.com/robgonnella/ops/internal/config" + "github.com/robgonnella/ops/internal/logger" "github.com/robgonnella/ops/internal/ui/key" "github.com/robgonnella/ops/internal/ui/style" ) @@ -21,6 +22,8 @@ func NewConfigContext( onSelect func(id int), onDelete func(name string, id int), ) *ConfigContext { + log := logger.New() + colHeaders := []string{"ID", "Name", "Target", "SSH-User", "SSH-Identity", "Overrides"} table := createTable("Context", colHeaders) @@ -31,7 +34,12 @@ func NewConfigContext( idStr := table.GetCell(row, 0).Text name := table.GetCell(row, 1).Text - id, _ := strconv.Atoi(idStr) + id, err := strconv.Atoi(idStr) + + if err != nil { + log.Error().Err(err).Msg("failed to delete context") + return nil + } onDelete(name, id) @@ -41,7 +49,11 @@ func NewConfigContext( if evt.Key() == key.KeyEnter { row, _ := table.GetSelection() idStr := table.GetCell(row, 0).Text - id, _ := strconv.Atoi(idStr) + id, err := strconv.Atoi(idStr) + if err != nil { + log.Error().Err(err).Msg("failed to select new context") + return nil + } onSelect(id) return nil } diff --git a/internal/ui/component/event.go b/internal/ui/component/event.go index b4a22ac..fcf3b3f 100644 --- a/internal/ui/component/event.go +++ b/internal/ui/component/event.go @@ -1,7 +1,6 @@ package component import ( - "context" "strconv" "github.com/rivo/tview" @@ -11,8 +10,6 @@ import ( ) type EventTable struct { - ctx context.Context - cancel context.CancelFunc table *tview.Table columnHeaders []string count uint @@ -31,11 +28,7 @@ func NewEventTable() *EventTable { "STATUS", } - ctx, cancel := context.WithCancel(context.Background()) - return &EventTable{ - ctx: ctx, - cancel: cancel, table: createTable("events", columnHeaders), columnHeaders: columnHeaders, count: 0, diff --git a/internal/ui/component/server.go b/internal/ui/component/server.go index 43a7b8c..e6bdad0 100644 --- a/internal/ui/component/server.go +++ b/internal/ui/component/server.go @@ -1,8 +1,6 @@ package component import ( - "context" - "github.com/gdamore/tcell/v2" "github.com/rivo/tview" "github.com/robgonnella/ops/internal/server" @@ -11,8 +9,6 @@ import ( ) type ServerTable struct { - ctx context.Context - cancel context.CancelFunc table *tview.Table columnHeaders []string } @@ -20,8 +16,6 @@ type ServerTable struct { func NewServerTable(OnSSH func(ip string)) *ServerTable { columnHeaders := []string{"HOSTNAME", "IP", "ID", "OS", "SSH", "STATUS"} - ctx, cancel := context.WithCancel(context.Background()) - table := createTable("servers", columnHeaders) table.SetInputCapture(func(evt *tcell.EventKey) *tcell.EventKey { @@ -36,8 +30,6 @@ func NewServerTable(OnSSH func(ip string)) *ServerTable { }) return &ServerTable{ - ctx: ctx, - cancel: cancel, table: table, columnHeaders: columnHeaders, } diff --git a/internal/ui/launch.go b/internal/ui/launch.go index bb0a233..e90ee5f 100644 --- a/internal/ui/launch.go +++ b/internal/ui/launch.go @@ -62,7 +62,13 @@ func (u *UI) Launch() error { log.Fatal().Err(err).Msg("failed to create app core") } - u.view = newView(*userIP, appCore) + allConfigs, err := appCore.GetConfigs() + + if err != nil { + log.Fatal().Err(err).Msg("failed to retrieve configs") + } + + u.view = newView(*userIP, allConfigs, appCore) os.Stdout, _ = os.Open(os.DevNull) os.Stderr, _ = os.Open(os.DevNull) diff --git a/internal/ui/view.go b/internal/ui/view.go index e77bf8c..7835dd9 100644 --- a/internal/ui/view.go +++ b/internal/ui/view.go @@ -1,7 +1,6 @@ package ui import ( - "context" "fmt" "os" "os/exec" @@ -30,8 +29,6 @@ func WithFocusedView(name string) ViewOption { } type view struct { - ctx context.Context - cancel context.CancelFunc app *tview.Application root *tview.Flex pages *tview.Pages @@ -53,7 +50,7 @@ type view struct { log logger.Logger } -func newView(userIP string, appCore *core.Core) *view { +func newView(userIP string, allConfigs []*config.Config, appCore *core.Core) *view { log := logger.New() v := &view{ @@ -61,21 +58,21 @@ func newView(userIP string, appCore *core.Core) *view { appCore: appCore, } - v.initialize(userIP) + v.initialize(userIP, allConfigs) return v } -func (v *view) initialize(userIP string, options ...ViewOption) { - v.ctx, v.cancel = context.WithCancel(context.Background()) - +func (v *view) initialize( + userIP string, + allConfigs []*config.Config, + options ...ViewOption, +) { v.viewNames = []string{"servers", "events", "context", "configure"} v.showingSwitchViewInput = false v.app = tview.NewApplication() - allConfigs, _ := v.appCore.GetConfigs() - v.root = tview.NewFlex().SetDirection(tview.FlexRow) v.pages = tview.NewPages() @@ -109,8 +106,8 @@ func (v *view) initialize(userIP string, options ...ViewOption) { AddItem(v.header.Primitive(), 12, 1, false). AddItem(v.pages, 0, 1, true) - v.serverUpdateChan = make(chan []*server.Server, 100) - v.eventUpdateChan = make(chan *event.Event, 100) + v.serverUpdateChan = make(chan []*server.Server) + v.eventUpdateChan = make(chan *event.Event) v.focusedName = "servers" for _, o := range options { @@ -360,23 +357,22 @@ func (v *view) onSSH(ip string) { func (v *view) processBackgroundServerUpdates() { go func() { for { - select { - case <-v.ctx.Done(): + servers, ok := <-v.serverUpdateChan + if !ok { return - case servers := <-v.serverUpdateChan: - v.app.QueueUpdateDraw(func() { - sort.Slice(servers, func(i, j int) bool { - if servers[i].Hostname == "unknown" { - return false - } - if servers[j].Hostname == "unknown" { - return true - } - return servers[i].Hostname < servers[j].Hostname - }) - v.serverTable.UpdateTable(servers) - }) } + v.app.QueueUpdateDraw(func() { + sort.Slice(servers, func(i, j int) bool { + if servers[i].Hostname == "unknown" { + return false + } + if servers[j].Hostname == "unknown" { + return true + } + return servers[i].Hostname < servers[j].Hostname + }) + v.serverTable.UpdateTable(servers) + }) } }() } @@ -384,14 +380,13 @@ func (v *view) processBackgroundServerUpdates() { func (v *view) processBackgroundEventUpdates() { go func() { for { - select { - case <-v.ctx.Done(): + evt, ok := <-v.eventUpdateChan + if !ok { return - case evt := <-v.eventUpdateChan: - v.app.QueueUpdateDraw(func() { - v.eventTable.UpdateTable(evt) - }) } + v.app.QueueUpdateDraw(func() { + v.eventTable.UpdateTable(evt) + }) } }() } @@ -416,11 +411,8 @@ func (v *view) stop() { v.appCore.RemoveEventListener(v.eventListenerID) v.serverPollListenerID = 0 v.eventListenerID = 0 - v.cancel() v.appCore.Stop() v.app.Stop() - v.ctx = nil - v.cancel = nil } func (v *view) restart(options ...ViewOption) { @@ -441,8 +433,15 @@ func (v *view) restart(options ...ViewOption) { } v.appCore = appCore - v.ctx, v.cancel = context.WithCancel(context.Background()) - v.initialize(*userIP, options...) + + allConfigs, err := v.appCore.GetConfigs() + + if err != nil { + restoreStdout() + v.log.Fatal().Err(err).Msg("failed to retrieve configs") + } + + v.initialize(*userIP, allConfigs, options...) if err := v.run(); err != nil { restoreStdout()