diff --git a/keysign/signature_notifier_test.go b/keysign/signature_notifier_test.go index 130f6e0..c8ead8f 100644 --- a/keysign/signature_notifier_test.go +++ b/keysign/signature_notifier_test.go @@ -155,10 +155,16 @@ func TestSignatureNotifierBroadcastFirst(t *testing.T) { p1, p2, })) - n1.notifierLock.Lock() - assert.Contains(t, n1.notifiers, messageID) - notifier := n1.notifiers[messageID] - n1.notifierLock.Unlock() + var notifier *notifier + var ok bool + assert.Eventually(t, func() bool { + n1.notifierLock.Lock() + defer n1.notifierLock.Unlock() + + notifier, ok = n1.notifiers[messageID] + return ok + }, time.Second, time.Millisecond*100) + assert.False(t, notifier.readyToProcess()) assert.Equal(t, defaultNotifierTTL, notifier.ttl) diff --git a/p2p/communication.go b/p2p/communication.go index 81ab547..cdf94a6 100644 --- a/p2p/communication.go +++ b/p2p/communication.go @@ -173,7 +173,7 @@ func (c *Communication) readFromStream(stream network.Stream) { return } c.logger.Debug().Msgf(">>>>>>>[%s] %s", wrappedMsg.MessageType, string(wrappedMsg.Payload)) - channel := c.getSubscriber(wrappedMsg.MessageType, wrappedMsg.MsgID) + channel := c.getSubscriberWithRetry(wrappedMsg.MessageType, wrappedMsg.MsgID) if nil == channel { c.logger.Debug().Msgf("no MsgID %s found for this message", wrappedMsg.MsgID) c.logger.Debug().Msgf("no MsgID %s found for this message", wrappedMsg.MessageType) @@ -445,6 +445,19 @@ func (c *Communication) getSubscriber(topic messages.THORChainTSSMessageType, ms return messageIDSubscribers.GetSubscriber(msgID) } +// getSubscriberWithRetry tries to get a subscriber a few times to avoid race conditions +func (c *Communication) getSubscriberWithRetry(topic messages.THORChainTSSMessageType, msgID string) chan *Message { + var res chan *Message + for i := 0; i < 3; i++ { + res = c.getSubscriber(topic, msgID) + if res != nil { + return res + } + time.Sleep(time.Millisecond * 50) + } + return res +} + func (c *Communication) CancelSubscribe(topic messages.THORChainTSSMessageType, msgID string) { c.subscriberLocker.Lock() defer c.subscriberLocker.Unlock()