diff --git a/examples/chatroom/main.go b/examples/chatroom/main.go index 2eab51d0..ef00f904 100644 --- a/examples/chatroom/main.go +++ b/examples/chatroom/main.go @@ -23,8 +23,8 @@ func main() { var upgrader = gws.NewUpgrader(handler, &gws.ServerOption{ PermessageDeflate: gws.PermessageDeflate{ Enabled: true, - ServerContextTakeover: false, - ClientContextTakeover: false, + ServerContextTakeover: true, + ClientContextTakeover: true, }, // 在querystring里面传入用户名 @@ -36,7 +36,7 @@ func main() { return false } session.Store("name", name) - session.Store("key", r.Header.Get("Sec-WebSocket-Key")) + session.Store("websocketKey", r.Header.Get("Sec-WebSocket-Key")) return true }, }) @@ -59,42 +59,46 @@ func main() { } } +func MustLoad[T any](session gws.SessionStorage, key string) (v T) { + if value, exist := session.Load(key); exist { + v = value.(T) + } + return +} + func NewWebSocket() *WebSocket { - return &WebSocket{sessions: gws.NewConcurrentMap[string, *gws.Conn](16)} + return &WebSocket{ + sessions: gws.NewConcurrentMap[string, *gws.Conn](16, 128), + } } type WebSocket struct { sessions *gws.ConcurrentMap[string, *gws.Conn] // 使用内置的ConcurrentMap存储连接, 可以减少锁冲突 } -func (c *WebSocket) getName(socket *gws.Conn) string { - name, _ := socket.Session().Load("name") - return name.(string) -} - -func (c *WebSocket) getKey(socket *gws.Conn) string { - name, _ := socket.Session().Load("key") - return name.(string) -} - func (c *WebSocket) OnOpen(socket *gws.Conn) { - name := c.getName(socket) + name := MustLoad[string](socket.Session(), "name") if conn, ok := c.sessions.Load(name); ok { - conn.WriteClose(1000, []byte("connection replaced")) + conn.WriteClose(1000, []byte("connection is replaced")) } - socket.SetDeadline(time.Now().Add(PingInterval + HeartbeatWaitTimeout)) + _ = socket.SetDeadline(time.Now().Add(PingInterval + HeartbeatWaitTimeout)) c.sessions.Store(name, socket) log.Printf("%s connected\n", name) } func (c *WebSocket) OnClose(socket *gws.Conn, err error) { - name := c.getName(socket) - key := c.getKey(socket) - if mSocket, ok := c.sessions.Load(name); ok { - if mKey := c.getKey(mSocket); mKey == key { - c.sessions.Delete(name) + name := MustLoad[string](socket.Session(), "name") + sharding := c.sessions.GetSharding(name) + sharding.Lock() + defer sharding.Unlock() + + if conn, ok := sharding.Load(name); ok { + key0 := MustLoad[string](socket.Session(), "websocketKey") + if key1 := MustLoad[string](conn.Session(), "websocketKey"); key1 == key0 { + sharding.Delete(name) } } + log.Printf("onerror, name=%s, msg=%s\n", name, err.Error()) } @@ -114,14 +118,14 @@ func (c *WebSocket) OnMessage(socket *gws.Conn, message *gws.Message) { defer message.Close() // chrome websocket不支持ping方法, 所以在text frame里面模拟ping - if b := message.Data.Bytes(); len(b) == 4 && string(b) == "ping" { + if b := message.Bytes(); len(b) == 4 && string(b) == "ping" { c.OnPing(socket, nil) return } var input = &Input{} - _ = json.Unmarshal(message.Data.Bytes(), input) + _ = json.Unmarshal(message.Bytes(), input) if conn, ok := c.sessions.Load(input.To); ok { - conn.WriteMessage(gws.OpcodeText, message.Data.Bytes()) + _ = conn.WriteMessage(gws.OpcodeText, message.Bytes()) } } diff --git a/session_storage.go b/session_storage.go index f7acaa5e..94999bde 100644 --- a/session_storage.go +++ b/session_storage.go @@ -62,78 +62,114 @@ type ( ConcurrentMap[K comparable, V any] struct { hasher maphash.Hasher[K] sharding uint64 - buckets []*bucket[K, V] - } - - bucket[K comparable, V any] struct { - sync.Mutex - m map[K]V + buckets []*Map[K, V] } ) -func NewConcurrentMap[K comparable, V any](sharding uint64) *ConcurrentMap[K, V] { - sharding = internal.SelectValue(sharding == 0, 16, sharding) +// NewConcurrentMap create a new concurrency-safe map +// arg0 represents the number of shardings; arg1 represents the initialized capacity of a sharding. +func NewConcurrentMap[K comparable, V any](size ...uint64) *ConcurrentMap[K, V] { + sharding, capacity := uint64(16), uint64(0) + if len(size) >= 1 { + sharding = size[0] + } + if len(size) >= 2 { + capacity = size[1] + } + sharding = internal.SelectValue(sharding <= 0, 16, sharding) sharding = internal.ToBinaryNumber(sharding) var cm = &ConcurrentMap[K, V]{ hasher: maphash.NewHasher[K](), sharding: sharding, - buckets: make([]*bucket[K, V], sharding), + buckets: make([]*Map[K, V], sharding), } for i, _ := range cm.buckets { - cm.buckets[i] = &bucket[K, V]{m: make(map[K]V)} + cm.buckets[i] = &Map[K, V]{m: make(map[K]V, capacity)} } return cm } -func (c *ConcurrentMap[K, V]) getBucket(key K) *bucket[K, V] { +// GetSharding returns a map sharding for a key +func (c *ConcurrentMap[K, V]) GetSharding(key K) *Map[K, V] { var hashCode = c.hasher.Hash(key) var index = hashCode & (c.sharding - 1) return c.buckets[index] } +// Len returns the number of elements in the map func (c *ConcurrentMap[K, V]) Len() int { var length = 0 for _, b := range c.buckets { b.Lock() - length += len(b.m) + length += b.Len() b.Unlock() } return length } -func (c *ConcurrentMap[K, V]) Load(key K) (value V, exist bool) { - var b = c.getBucket(key) +// Load returns the value stored in the map for a key, or nil if no +// value is present. +// The ok result indicates whether value was found in the map. +func (c *ConcurrentMap[K, V]) Load(key K) (value V, ok bool) { + var b = c.GetSharding(key) b.Lock() - value, exist = b.m[key] + value, ok = b.Load(key) b.Unlock() return } +// Delete deletes the value for a key. func (c *ConcurrentMap[K, V]) Delete(key K) { - var b = c.getBucket(key) + var b = c.GetSharding(key) b.Lock() - delete(b.m, key) + b.Delete(key) b.Unlock() } +// Store sets the value for a key. func (c *ConcurrentMap[K, V]) Store(key K, value V) { - var b = c.getBucket(key) + var b = c.GetSharding(key) b.Lock() - b.m[key] = value + b.Store(key, value) b.Unlock() } // Range calls f sequentially for each key and value present in the map. // If f returns false, range stops the iteration. func (c *ConcurrentMap[K, V]) Range(f func(key K, value V) bool) { - for _, b := range c.buckets { + var next = true + var cb = func(k K, v V) bool { + next = f(k, v) + return next + } + for i := uint64(0); i < c.sharding && next; i++ { + var b = c.buckets[i] b.Lock() - for k, v := range b.m { - if !f(k, v) { - b.Unlock() - return - } - } + b.Range(cb) b.Unlock() } } + +type Map[K comparable, V any] struct { + sync.Mutex + m map[K]V +} + +func (c *Map[K, V]) Len() int { return len(c.m) } + +func (c *Map[K, V]) Load(key K) (value V, ok bool) { + value, ok = c.m[key] + return +} + +func (c *Map[K, V]) Delete(key K) { delete(c.m, key) } + +func (c *Map[K, V]) Store(key K, value V) { c.m[key] = value } + +func (c *Map[K, V]) Range(f func(K, V) bool) { + for k, v := range c.m { + if !f(k, v) { + return + } + } +} diff --git a/session_storage_test.go b/session_storage_test.go index 3b215f07..8a6e66d7 100644 --- a/session_storage_test.go +++ b/session_storage_test.go @@ -123,6 +123,15 @@ func TestConcurrentMap(t *testing.T) { as.Equal(v, v1) } as.Equal(len(m1), m2.Len()) + + t.Run("", func(t *testing.T) { + var sum = 0 + var cm = NewConcurrentMap[string, int](8, 8) + for _, item := range cm.buckets { + sum += len(item.m) + } + assert.Equal(t, sum, 0) + }) } func TestConcurrentMap_Range(t *testing.T) {