Skip to content

Commit

Permalink
export GetSharding method
Browse files Browse the repository at this point in the history
  • Loading branch information
lixizan committed Jun 7, 2024
1 parent 2042b4f commit cfdae56
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 51 deletions.
54 changes: 29 additions & 25 deletions examples/chatroom/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ func main() {
var upgrader = gws.NewUpgrader(handler, &gws.ServerOption{
PermessageDeflate: gws.PermessageDeflate{
Enabled: true,
ServerContextTakeover: false,
ClientContextTakeover: false,
ServerContextTakeover: true,
ClientContextTakeover: true,
},

// 在querystring里面传入用户名
Expand All @@ -36,7 +36,7 @@ func main() {
return false
}
session.Store("name", name)
session.Store("key", r.Header.Get("Sec-WebSocket-Key"))
session.Store("websocketKey", r.Header.Get("Sec-WebSocket-Key"))
return true
},
})
Expand All @@ -59,42 +59,46 @@ func main() {
}
}

func MustLoad[T any](session gws.SessionStorage, key string) (v T) {
if value, exist := session.Load(key); exist {
v = value.(T)
}
return
}

func NewWebSocket() *WebSocket {
return &WebSocket{sessions: gws.NewConcurrentMap[string, *gws.Conn](16)}
return &WebSocket{
sessions: gws.NewConcurrentMap[string, *gws.Conn](16, 128),
}
}

type WebSocket struct {
sessions *gws.ConcurrentMap[string, *gws.Conn] // 使用内置的ConcurrentMap存储连接, 可以减少锁冲突
}

func (c *WebSocket) getName(socket *gws.Conn) string {
name, _ := socket.Session().Load("name")
return name.(string)
}

func (c *WebSocket) getKey(socket *gws.Conn) string {
name, _ := socket.Session().Load("key")
return name.(string)
}

func (c *WebSocket) OnOpen(socket *gws.Conn) {
name := c.getName(socket)
name := MustLoad[string](socket.Session(), "name")
if conn, ok := c.sessions.Load(name); ok {
conn.WriteClose(1000, []byte("connection replaced"))
conn.WriteClose(1000, []byte("connection is replaced"))
}
socket.SetDeadline(time.Now().Add(PingInterval + HeartbeatWaitTimeout))
_ = socket.SetDeadline(time.Now().Add(PingInterval + HeartbeatWaitTimeout))
c.sessions.Store(name, socket)
log.Printf("%s connected\n", name)
}

func (c *WebSocket) OnClose(socket *gws.Conn, err error) {
name := c.getName(socket)
key := c.getKey(socket)
if mSocket, ok := c.sessions.Load(name); ok {
if mKey := c.getKey(mSocket); mKey == key {
c.sessions.Delete(name)
name := MustLoad[string](socket.Session(), "name")
sharding := c.sessions.GetSharding(name)
sharding.Lock()
defer sharding.Unlock()

if conn, ok := sharding.Load(name); ok {
key0 := MustLoad[string](socket.Session(), "websocketKey")
if key1 := MustLoad[string](conn.Session(), "websocketKey"); key1 == key0 {
sharding.Delete(name)
}
}

log.Printf("onerror, name=%s, msg=%s\n", name, err.Error())
}

Expand All @@ -114,14 +118,14 @@ func (c *WebSocket) OnMessage(socket *gws.Conn, message *gws.Message) {
defer message.Close()

// chrome websocket不支持ping方法, 所以在text frame里面模拟ping
if b := message.Data.Bytes(); len(b) == 4 && string(b) == "ping" {
if b := message.Bytes(); len(b) == 4 && string(b) == "ping" {
c.OnPing(socket, nil)
return
}

var input = &Input{}
_ = json.Unmarshal(message.Data.Bytes(), input)
_ = json.Unmarshal(message.Bytes(), input)
if conn, ok := c.sessions.Load(input.To); ok {
conn.WriteMessage(gws.OpcodeText, message.Data.Bytes())
_ = conn.WriteMessage(gws.OpcodeText, message.Bytes())
}
}
88 changes: 62 additions & 26 deletions session_storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,78 +62,114 @@ type (
ConcurrentMap[K comparable, V any] struct {
hasher maphash.Hasher[K]
sharding uint64
buckets []*bucket[K, V]
}

bucket[K comparable, V any] struct {
sync.Mutex
m map[K]V
buckets []*Map[K, V]
}
)

func NewConcurrentMap[K comparable, V any](sharding uint64) *ConcurrentMap[K, V] {
sharding = internal.SelectValue(sharding == 0, 16, sharding)
// NewConcurrentMap create a new concurrency-safe map
// arg0 represents the number of shardings; arg1 represents the initialized capacity of a sharding.
func NewConcurrentMap[K comparable, V any](size ...uint64) *ConcurrentMap[K, V] {
sharding, capacity := uint64(16), uint64(0)
if len(size) >= 1 {
sharding = size[0]
}
if len(size) >= 2 {
capacity = size[1]
}
sharding = internal.SelectValue(sharding <= 0, 16, sharding)
sharding = internal.ToBinaryNumber(sharding)
var cm = &ConcurrentMap[K, V]{
hasher: maphash.NewHasher[K](),
sharding: sharding,
buckets: make([]*bucket[K, V], sharding),
buckets: make([]*Map[K, V], sharding),
}
for i, _ := range cm.buckets {
cm.buckets[i] = &bucket[K, V]{m: make(map[K]V)}
cm.buckets[i] = &Map[K, V]{m: make(map[K]V, capacity)}
}
return cm
}

func (c *ConcurrentMap[K, V]) getBucket(key K) *bucket[K, V] {
// GetSharding returns a map sharding for a key
func (c *ConcurrentMap[K, V]) GetSharding(key K) *Map[K, V] {
var hashCode = c.hasher.Hash(key)
var index = hashCode & (c.sharding - 1)
return c.buckets[index]
}

// Len returns the number of elements in the map
func (c *ConcurrentMap[K, V]) Len() int {
var length = 0
for _, b := range c.buckets {
b.Lock()
length += len(b.m)
length += b.Len()
b.Unlock()
}
return length
}

func (c *ConcurrentMap[K, V]) Load(key K) (value V, exist bool) {
var b = c.getBucket(key)
// Load returns the value stored in the map for a key, or nil if no
// value is present.
// The ok result indicates whether value was found in the map.
func (c *ConcurrentMap[K, V]) Load(key K) (value V, ok bool) {
var b = c.GetSharding(key)
b.Lock()
value, exist = b.m[key]
value, ok = b.Load(key)
b.Unlock()
return
}

// Delete deletes the value for a key.
func (c *ConcurrentMap[K, V]) Delete(key K) {
var b = c.getBucket(key)
var b = c.GetSharding(key)
b.Lock()
delete(b.m, key)
b.Delete(key)
b.Unlock()
}

// Store sets the value for a key.
func (c *ConcurrentMap[K, V]) Store(key K, value V) {
var b = c.getBucket(key)
var b = c.GetSharding(key)
b.Lock()
b.m[key] = value
b.Store(key, value)
b.Unlock()
}

// Range calls f sequentially for each key and value present in the map.
// If f returns false, range stops the iteration.
func (c *ConcurrentMap[K, V]) Range(f func(key K, value V) bool) {
for _, b := range c.buckets {
var next = true
var cb = func(k K, v V) bool {
next = f(k, v)
return next
}
for i := uint64(0); i < c.sharding && next; i++ {
var b = c.buckets[i]
b.Lock()
for k, v := range b.m {
if !f(k, v) {
b.Unlock()
return
}
}
b.Range(cb)
b.Unlock()
}
}

type Map[K comparable, V any] struct {
sync.Mutex
m map[K]V
}

func (c *Map[K, V]) Len() int { return len(c.m) }

func (c *Map[K, V]) Load(key K) (value V, ok bool) {
value, ok = c.m[key]
return
}

func (c *Map[K, V]) Delete(key K) { delete(c.m, key) }

func (c *Map[K, V]) Store(key K, value V) { c.m[key] = value }

func (c *Map[K, V]) Range(f func(K, V) bool) {
for k, v := range c.m {
if !f(k, v) {
return
}
}
}
9 changes: 9 additions & 0 deletions session_storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,15 @@ func TestConcurrentMap(t *testing.T) {
as.Equal(v, v1)
}
as.Equal(len(m1), m2.Len())

t.Run("", func(t *testing.T) {
var sum = 0
var cm = NewConcurrentMap[string, int](8, 8)
for _, item := range cm.buckets {
sum += len(item.m)
}
assert.Equal(t, sum, 0)
})
}

func TestConcurrentMap_Range(t *testing.T) {
Expand Down

0 comments on commit cfdae56

Please sign in to comment.