Skip to content

Commit

Permalink
feat(protocol): introduce the concept of streams to prepare for futur…
Browse files Browse the repository at this point in the history
…e support of HTTP2 and Mongo (#258)

Signed-off-by: 烈香 <[email protected]>
  • Loading branch information
hengyoush authored Jan 9, 2025
1 parent 6d507da commit 05b2c40
Show file tree
Hide file tree
Showing 10 changed files with 88 additions and 30 deletions.
31 changes: 19 additions & 12 deletions agent/conn/conntrack.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ type Connection4 struct {

reqStreamBuffer *buffer.StreamBuffer
respStreamBuffer *buffer.StreamBuffer
ReqQueue []protocol.ParsedMessage
ReqQueue map[protocol.StreamId]*protocol.ParsedMessageQueue
RespQueue map[protocol.StreamId]*protocol.ParsedMessageQueue
lastReqMadeProgressTime int64
lastRespMadeProgressTime int64
RespQueue []protocol.ParsedMessage
StreamEvents *KernEventStream
protocolParsers map[bpf.AgentTrafficProtocolT]protocol.ProtocolStreamParser

Expand Down Expand Up @@ -83,8 +83,8 @@ func NewConnFromEvent(event *bpf.AgentConnEvtT, p *Processor) *Connection4 {

reqStreamBuffer: buffer.New(1024 * 1024),
respStreamBuffer: buffer.New(1024 * 1024),
ReqQueue: make([]protocol.ParsedMessage, 0),
RespQueue: make([]protocol.ParsedMessage, 0),
ReqQueue: make(map[protocol.StreamId]*protocol.ParsedMessageQueue),
RespQueue: make(map[protocol.StreamId]*protocol.ParsedMessageQueue),

prevConn: []*Connection4{},

Expand Down Expand Up @@ -450,8 +450,8 @@ func (c *Connection4) addDataToBufferAndTryParse(data []byte, ke *bpf.AgentKernE
if c.Role == bpf.AgentEndpointRoleTKRoleUnknown {
respSteamMessageType = protocol.Unknown
}
c.parseStreamBuffer(c.reqStreamBuffer, reqSteamMessageType, &c.ReqQueue, ke)
c.parseStreamBuffer(c.respStreamBuffer, respSteamMessageType, &c.RespQueue, ke)
c.parseStreamBuffer(c.reqStreamBuffer, reqSteamMessageType, c.ReqQueue, ke)
c.parseStreamBuffer(c.respStreamBuffer, respSteamMessageType, c.RespQueue, ke)
return true
}
func (c *Connection4) OnSslDataEvent(data []byte, event *bpf.SslData, recordChannel chan RecordWithConn) {
Expand All @@ -467,7 +467,7 @@ func (c *Connection4) OnSslDataEvent(data []byte, event *bpf.SslData, recordChan
return
}

records := parser.Match(&c.ReqQueue, &c.RespQueue)
records := parser.Match(c.ReqQueue, c.RespQueue)
if len(records) != 0 {
for _, record := range records {
recordChannel <- RecordWithConn{record, c}
Expand Down Expand Up @@ -501,7 +501,7 @@ func (c *Connection4) OnSyscallEvent(data []byte, event *bpf.SyscallEventData, r
panic("no protocol parser!")
}

records := parser.Match(&c.ReqQueue, &c.RespQueue)
records := parser.Match(c.ReqQueue, c.RespQueue)
if len(records) != 0 {
for _, record := range records {
recordChannel <- RecordWithConn{record, c}
Expand All @@ -510,7 +510,7 @@ func (c *Connection4) OnSyscallEvent(data []byte, event *bpf.SyscallEventData, r
return true
}

func (c *Connection4) parseStreamBuffer(streamBuffer *buffer.StreamBuffer, messageType protocol.MessageType, resultQueue *[]protocol.ParsedMessage, ke *bpf.AgentKernEvt) {
func (c *Connection4) parseStreamBuffer(streamBuffer *buffer.StreamBuffer, messageType protocol.MessageType, resultQueue map[protocol.StreamId]*protocol.ParsedMessageQueue, ke *bpf.AgentKernEvt) {
parser := c.GetProtocolParser(c.Protocol)
if parser == nil {
return
Expand Down Expand Up @@ -551,7 +551,14 @@ func (c *Connection4) parseStreamBuffer(streamBuffer *buffer.StreamBuffer, messa
if len(parseResult.ParsedMessages) > 0 && parseResult.ParsedMessages[0].IsReq() != (messageType == protocol.Request) {
streamBuffer.RemovePrefix(parseResult.ReadBytes)
} else {
*resultQueue = append(*resultQueue, parseResult.ParsedMessages...)
for _, parsedMessage := range parseResult.ParsedMessages {
streamId := parsedMessage.StreamId()
if resultQueue[streamId] == nil {
queue := protocol.ParsedMessageQueue(make([]protocol.ParsedMessage, 0))
resultQueue[streamId] = &queue
}
*resultQueue[streamId] = append(*resultQueue[streamId], parsedMessage)
}
streamBuffer.RemovePrefix(parseResult.ReadBytes)
}
}
Expand Down Expand Up @@ -748,6 +755,6 @@ func (c *Connection4) GetProtocolParser(p bpf.AgentTrafficProtocolT) protocol.Pr
func (c *Connection4) resetParseProgress() {
c.reqStreamBuffer.Clear()
c.respStreamBuffer.Clear()
c.ReqQueue = c.ReqQueue[:]
c.RespQueue = c.RespQueue[:]
c.ReqQueue = make(map[protocol.StreamId]*protocol.ParsedMessageQueue)
c.RespQueue = make(map[protocol.StreamId]*protocol.ParsedMessageQueue)
}
2 changes: 1 addition & 1 deletion agent/protocol/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"kyanos/agent/buffer"
)

func matchByTimestamp(reqStream *[]ParsedMessage, respStream *[]ParsedMessage) []Record {
func matchByTimestamp(reqStream *ParsedMessageQueue, respStream *ParsedMessageQueue) []Record {
if len(*reqStream) == 0 || len(*respStream) == 0 {
return nil
}
Expand Down
15 changes: 14 additions & 1 deletion agent/protocol/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@ var HTTP_BOUNDARY_MARKER = "\r\n\r\n"
type HTTPStreamParser struct {
}

func (h *HTTPStreamParser) Match(reqStream *[]ParsedMessage, respStream *[]ParsedMessage) []Record {
func (h *HTTPStreamParser) Match(reqStreams map[StreamId]*ParsedMessageQueue, respStreams map[StreamId]*ParsedMessageQueue) []Record {
reqStream, ok1 := reqStreams[0]
respStream, ok2 := respStreams[0]
if !ok1 || !ok2 {
return []Record{}
}
return matchByTimestamp(reqStream, respStream)
}

Expand Down Expand Up @@ -214,6 +219,10 @@ func (req *ParsedHttpRequest) IsReq() bool {
return true
}

func (req *ParsedHttpRequest) StreamId() StreamId {
return 0
}

type ParsedHttpResponse struct {
FrameBase
buf []byte
Expand All @@ -233,6 +242,10 @@ func (resp *ParsedHttpResponse) IsReq() bool {
return false
}

func (req *ParsedHttpResponse) StreamId() StreamId {
return 0
}

var _ ProtocolFilter = HttpFilter{}

type HttpFilter struct {
Expand Down
11 changes: 8 additions & 3 deletions agent/protocol/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,12 @@ func (m *MysqlParser) FindBoundary(streamBuffer *buffer.StreamBuffer, messageTyp
return -1
}

func (m *MysqlParser) Match(reqStream *[]ParsedMessage, respStream *[]ParsedMessage) []Record {
func (m *MysqlParser) Match(reqStreams map[StreamId]*ParsedMessageQueue, respStreams map[StreamId]*ParsedMessageQueue) []Record {
reqStream, ok1 := reqStreams[0]
respStream, ok2 := respStreams[0]
if !ok1 || !ok2 {
return []Record{}
}
records := make([]Record, 0)
for len(*reqStream) != 0 {
reqPacket := (*reqStream)[0].(*MysqlPacket)
Expand Down Expand Up @@ -112,13 +117,13 @@ func (m *MysqlParser) Match(reqStream *[]ParsedMessage, respStream *[]ParsedMess
return records
}

func syncRespQueue(reqPacket *MysqlPacket, respStream *[]ParsedMessage) {
func syncRespQueue(reqPacket *MysqlPacket, respStream *ParsedMessageQueue) {
for len(*respStream) != 0 && (*respStream)[0].TimestampNs() < reqPacket.TimestampNs() {
*respStream = (*respStream)[1:]
}
}

func getRespView(reqStream *[]ParsedMessage, respStream *[]ParsedMessage) []ParsedMessage {
func getRespView(reqStream *ParsedMessageQueue, respStream *ParsedMessageQueue) []ParsedMessage {
count := 0
for _, resp := range *respStream {
if len(*reqStream) > 1 && resp.TimestampNs() > (*reqStream)[1].TimestampNs() {
Expand Down
8 changes: 8 additions & 0 deletions agent/protocol/mysql/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ func (m *MysqlResponse) IsReq() bool {
return false
}

func (m *MysqlResponse) StreamId() StreamId {
return 0
}

const (
Unknwon RespStatus = iota
None
Expand Down Expand Up @@ -298,6 +302,10 @@ func (m *MysqlPacket) IsReq() bool {
return m.isReq
}

func (req *MysqlPacket) StreamId() StreamId {
return 0
}

type MysqlRequestPacket struct {
MysqlPacket
cmd byte
Expand Down
8 changes: 6 additions & 2 deletions agent/protocol/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@ import (

type ProtocolCreator func() ProtocolStreamParser

type StreamId int64

type ParsedMessageQueue []ParsedMessage

var ParsersMap map[bpf.AgentTrafficProtocolT]ProtocolCreator = make(map[bpf.AgentTrafficProtocolT]ProtocolCreator)

// TODO 修改未每一个processor有自己的parser
func GetParserByProtocol(protocol bpf.AgentTrafficProtocolT) ProtocolStreamParser {
parserCreator, ok := ParsersMap[protocol]
if ok {
Expand Down Expand Up @@ -60,7 +63,7 @@ func (r *Record) String(opt RecordToStringOptions) string {
type ProtocolStreamParser interface {
ParseStream(streamBuffer *buffer.StreamBuffer, messageType MessageType) ParseResult
FindBoundary(streamBuffer *buffer.StreamBuffer, messageType MessageType, startPos int) int
Match(reqStream *[]ParsedMessage, respStream *[]ParsedMessage) []Record
Match(reqStreams map[StreamId]*ParsedMessageQueue, respStreams map[StreamId]*ParsedMessageQueue) []Record
}

type ParsedMessage interface {
Expand All @@ -69,6 +72,7 @@ type ParsedMessage interface {
ByteSize() int
IsReq() bool
Seq() uint64
StreamId() StreamId
}

type ParseState int
Expand Down
11 changes: 10 additions & 1 deletion agent/protocol/redis..go
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,10 @@ func (m *RedisMessage) FormatToString() string {
return fmt.Sprintf("base=[%s] command=[%s] payload=[%s]", m.FrameBase.String(), m.command, m.payload)
}

func (req *RedisMessage) StreamId() StreamId {
return 0
}

func (r *RedisStreamParser) FindBoundary(streamBuffer *buffer.StreamBuffer, messageType MessageType, startPos int) int {
head := streamBuffer.Head().Buffer()
for ; startPos < len(head); startPos++ {
Expand All @@ -374,7 +378,12 @@ func (r *RedisStreamParser) FindBoundary(streamBuffer *buffer.StreamBuffer, mess
return -1
}

func (r *RedisStreamParser) Match(reqStream *[]ParsedMessage, respStream *[]ParsedMessage) []Record {
func (r *RedisStreamParser) Match(reqStreams map[StreamId]*ParsedMessageQueue, respStreams map[StreamId]*ParsedMessageQueue) []Record {
reqStream, ok1 := reqStreams[0]
respStream, ok2 := respStreams[0]
if !ok1 || !ok2 {
return []Record{}
}
return matchByTimestamp(reqStream, respStream)
}
func ParseSize(decoder *BinaryDecoder) (int, error) {
Expand Down
11 changes: 10 additions & 1 deletion agent/protocol/rocketmq/rocketmq.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ func (r *RocketMQMessage) IsReq() bool {
return r.isReq
}

func (r *RocketMQMessage) StreamId() protocol.StreamId {
return 0
}

func (r *RocketMQStreamParser) ParseStream(streamBuffer *buffer.StreamBuffer, messageType protocol.MessageType) protocol.ParseResult {
buffer := streamBuffer.Head().Buffer()
common.ProtocolParserLog.Debugf("ParseStream received buffer length: %d", len(buffer))
Expand Down Expand Up @@ -279,7 +283,12 @@ func (r *RocketMQStreamParser) FindBoundary(streamBuffer *buffer.StreamBuffer, m
return -1
}

func (r *RocketMQStreamParser) Match(reqStream *[]protocol.ParsedMessage, respStream *[]protocol.ParsedMessage) []protocol.Record {
func (r *RocketMQStreamParser) Match(reqStreams map[protocol.StreamId]*protocol.ParsedMessageQueue, respStreams map[protocol.StreamId]*protocol.ParsedMessageQueue) []protocol.Record {
reqStream, ok1 := reqStreams[0]
respStream, ok2 := respStreams[0]
if !ok1 || !ok2 {
return []protocol.Record{}
}
common.ProtocolParserLog.Debugf("Matching %d requests with %d responses.", len(*reqStream), len(*respStream))
records := []protocol.Record{}

Expand Down
4 changes: 3 additions & 1 deletion docs/cn/how-to-add-a-new-protocol.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ type ParsedMessage interface {
ByteSize() int
IsReq() bool
Seq() uint64
StreamId() StreamId
}
```

Expand All @@ -92,6 +93,7 @@ type ParsedMessage interface {
| `ByteSize()` | 返回消息的字节大小。 |
| `IsReq()` | 判断消息是否为请求。 |
| `Seq()` | 返回消息的字节流序列号, 可以从 `streamBuffer.Head().LeftBoundary()` 获取。 |
| `StreamId()` | 返回消息的 StreamId, 一般的协议可以直接返回 0,用于 HTTP2 这种多路复用的协议。 |

HTTP 的例子:

Expand All @@ -118,7 +120,7 @@ type ParsedHttpRequest struct {
type ProtocolStreamParser interface {
ParseStream(streamBuffer *buffer.StreamBuffer, messageType MessageType) ParseResult
FindBoundary(streamBuffer *buffer.StreamBuffer, messageType MessageType, startPos int) int
Match(reqStream *[]ParsedMessage, respStream *[]ParsedMessage) []Record
Match(reqStream *ParsedMessageQueue, respStream *ParsedMessageQueue) []Record
}
```

Expand Down
17 changes: 9 additions & 8 deletions docs/how-to-add-a-new-protocol.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,14 @@ type ParsedMessage interface {
}
```

| Method Name | Function |
| ------------------ | ----------------------------------------------------------------------------------------------------- |
| `FormatToString()` | Formats the message into a string representation. |
| `TimestampNs()` | Returns the timestamp of the message (in nanoseconds). |
| `ByteSize()` | Returns the byte size of the message. |
| `IsReq()` | Determines if the message is a request. |
| `Seq()` | Returns the sequence number of the byte stream.(Obtain Seq from `streamBuffer.Head().LeftBoundary()`) |
| Method Name | Function |
| ------------------ | --------------------------------------------------------------------------------------------------------------------------------------- |
| `FormatToString()` | Formats the message into a string representation. |
| `TimestampNs()` | Returns the timestamp of the message (in nanoseconds). |
| `ByteSize()` | Returns the byte size of the message. |
| `IsReq()` | Determines if the message is a request. |
| `Seq()` | Returns the sequence number of the byte stream.(Obtain Seq from `streamBuffer.Head().LeftBoundary()`) |
| `StreamId()` | Return the StreamId of the message. For most protocols, you can directly return 0. This is used for multiplexed protocols like HTTP2.。 |

Example for HTTP:

Expand Down Expand Up @@ -142,7 +143,7 @@ interface:
type ProtocolStreamParser interface {
ParseStream(streamBuffer *buffer.StreamBuffer, messageType MessageType) ParseResult
FindBoundary(streamBuffer *buffer.StreamBuffer, messageType MessageType, startPos int) int
Match(reqStream *[]ParsedMessage, respStream *[]ParsedMessage) []Record
Match(reqStream *ParsedMessageQueue, respStream *ParsedMessageQueue) []Record
}
```

Expand Down

0 comments on commit 05b2c40

Please sign in to comment.