-
Notifications
You must be signed in to change notification settings - Fork 98
/
session_serve.go
141 lines (118 loc) · 3.95 KB
/
session_serve.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
package remotedialer
import (
"context"
"errors"
"fmt"
"io"
"github.com/sirupsen/logrus"
)
// serveMessage decodes a single message from reader (via newServerMessage)
// and dispatches it to the Session handler that matches its messageType.
//
// A decoding failure is returned as-is. Only the Connect, AddClient,
// RemoveClient and SyncConnections handlers can report an error, which is
// propagated to the caller; Data, Pause, Resume and Error messages are handled
// without a result, and message types not listed here are ignored (nil error).
func (s *Session) serveMessage(ctx context.Context, reader io.Reader) error {
	msg, err := newServerMessage(reader)
	if err != nil {
		return err
	}

	if PrintTunnelData {
		logrus.Debug("REQUEST ", msg)
	}

	// result stays nil for the handlers that have no error return and for
	// unrecognized message types.
	var result error
	switch msg.messageType {
	case Connect:
		result = s.clientConnect(ctx, msg)
	case AddClient:
		result = s.addRemoteClient(msg.address)
	case RemoveClient:
		result = s.removeRemoteClient(msg.address)
	case SyncConnections:
		result = s.syncConnections(msg.body)
	case Data:
		s.connectionData(msg.connID, msg.body)
	case Pause:
		s.pauseConnection(msg.connID)
	case Resume:
		s.resumeConnection(msg.connID)
	case Error:
		s.closeConnection(msg.connID, msg.Err())
	}
	return result
}
// clientConnect handles a Connect request. The request is rejected with
// "connect not allowed" unless an auth callback is configured and approves the
// (proto, address) pair. On approval a connection is created, registered under
// the message's connID, and clientDial is started in a new goroutine to dial
// the target; this function returns without waiting for the dial.
func (s *Session) clientConnect(ctx context.Context, message *message) error {
	permitted := s.auth != nil && s.auth(message.proto, message.address)
	if !permitted {
		return errors.New("connect not allowed")
	}

	c := newConnection(message.connID, s, message.proto, message.address)
	s.addConnection(message.connID, c)

	// The dial runs asynchronously; its lifetime is governed by ctx and
	// clientDial (not visible here).
	go clientDial(ctx, s.dialer, c, message)
	return nil
}
// addRemoteClient registers the remote client described by address so it is
// reachable for requests. address is split by parseAddress into a client key
// and a session key, which are recorded via addSessionKey.
//
// It is a no-op returning nil when the session does not track remote clients
// (remoteClientKeys is nil). A malformed address yields an error that wraps the
// parseAddress error, so callers can inspect it with errors.Is / errors.As.
func (s *Session) addRemoteClient(address string) error {
	if s.remoteClientKeys == nil {
		return nil
	}
	clientKey, sessionKey, err := parseAddress(address)
	if err != nil {
		return fmt.Errorf("invalid remote Session %s: %w", address, err)
	}
	s.addSessionKey(clientKey, sessionKey)
	if PrintTunnelData {
		// NOTE(review): this logs the local session's key (s.sessionKey), not
		// the remote sessionKey parsed above — confirm which is intended.
		logrus.Debugf("ADD REMOTE CLIENT %s, SESSION %d", address, s.sessionKey)
	}
	return nil
}
// removeRemoteClient unregisters the remote client described by address.
// address is split by parseAddress into a client key and a session key, which
// are passed to removeSessionKey.
//
// A malformed address yields an error that wraps the parseAddress error, so
// callers can inspect it with errors.Is / errors.As.
func (s *Session) removeRemoteClient(address string) error {
	clientKey, sessionKey, err := parseAddress(address)
	if err != nil {
		return fmt.Errorf("invalid remote Session %s: %w", address, err)
	}
	s.removeSessionKey(clientKey, sessionKey)
	if PrintTunnelData {
		// NOTE(review): this logs the local session's key (s.sessionKey), not
		// the remote sessionKey parsed above — confirm which is intended.
		logrus.Debugf("REMOVE REMOTE CLIENT %s, SESSION %d", address, s.sessionKey)
	}
	return nil
}
// syncConnections reads the whole message body from r and decodes it into the
// set of connection IDs the client reports as active. That set is handed to
// compareAndCloseStaleConnections, which reconciles the session's own
// connections against it.
//
// Read and decode failures are returned wrapped with context.
func (s *Session) syncConnections(r io.Reader) error {
	var (
		raw []byte
		err error
	)
	if raw, err = io.ReadAll(r); err != nil {
		return fmt.Errorf("reading message body: %w", err)
	}

	active, decodeErr := decodeConnectionIDs(raw)
	if decodeErr != nil {
		return fmt.Errorf("decoding sync connections payload: %w", decodeErr)
	}

	s.compareAndCloseStaleConnections(active)
	return nil
}
// closeConnection detaches the connection registered under connID from the
// session and, if one was found, closes its tunnel via tunnelClose(err).
// An unknown connID is silently ignored.
//
// NOTE(review): the previous doc stated that tunnelClose notifies the other
// end with an error message and substitutes io.EOF for a nil err; that logic
// lives in tunnelClose (not visible here) — confirm there.
func (s *Session) closeConnection(connID int64, err error) {
	conn := s.removeConnection(connID)
	if conn == nil {
		return
	}
	conn.tunnelClose(err)
}
// connectionData processes an incoming Data message by handing body to the
// OnData method of the connection registered under connID (which presumably
// buffers it for the local reader — confirm in OnData).
//
// If no such connection exists, an error message naming
// clientKey/sessionKey/connID is written back over s.conn instead. If OnData
// fails, the connection is closed with that error.
func (s *Session) connectionData(connID int64, body io.Reader) {
	conn := s.getConnection(connID)
	if conn == nil {
		notFound := fmt.Errorf("connection not found %s/%d/%d", s.clientKey, s.sessionKey, connID)
		reply := newErrorMessage(connID, notFound)
		// Best effort: there is no caller to report a failed notification to,
		// so the write result is deliberately discarded.
		_, _ = reply.WriteTo(defaultDeadline(), s.conn)
		return
	}

	if err := conn.OnData(body); err != nil {
		s.closeConnection(connID, err)
	}
}
// pauseConnection forwards a Pause request to the connection registered under
// connID by calling its OnPause method (described as activating back-pressure).
// Unknown IDs are ignored.
func (s *Session) pauseConnection(connID int64) {
	conn := s.getConnection(connID)
	if conn == nil {
		return
	}
	conn.OnPause()
}
// resumeConnection forwards a Resume request to the connection registered
// under connID by calling its OnResume method (described as releasing
// back-pressure). Unknown IDs are ignored.
func (s *Session) resumeConnection(connID int64) {
	conn := s.getConnection(connID)
	if conn == nil {
		return
	}
	conn.OnResume()
}