Skip to content

Commit

Permalink
fix(bpf/ssl): first HTTPS request on the server side might not be cap…
Browse files Browse the repository at this point in the history
…tured (#259)
  • Loading branch information
hengyoush authored Jan 8, 2025
1 parent 6653fef commit 6d507da
Show file tree
Hide file tree
Showing 24 changed files with 253 additions and 78 deletions.
36 changes: 26 additions & 10 deletions agent/analysis/stat.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,33 +202,38 @@ func (s *StatRecorder) ReceiveRecord(r protocol.Record, connection *conn.Connect
annotatedRecord.ReqPlainTextSize = events.ingressMessage.ByteSize()
annotatedRecord.RespPlainTextSize = events.egressMessage.ByteSize()
}
canCalculateReadPathTime := !connection.IsSsl() || isKernEvtCanMatchSslEvt(events.sslReadSyscallEvents)
canCalculateWritePathTime := !connection.IsSsl() || isKernEvtCanMatchSslEvt(events.sslWriteSyscallEvents)
annotatedRecord.ReqSize = events.ingressKernLen
annotatedRecord.RespSize = events.egressKernLen
if annotatedRecord.StartTs != math.MaxUint64 && hasDevOutEvents {
if annotatedRecord.StartTs != math.MaxUint64 && hasDevOutEvents &&
(canCalculateReadPathTime && canCalculateWritePathTime) {
annotatedRecord.TotalDuration = float64(annotatedRecord.EndTs) - float64(annotatedRecord.StartTs)
}
if hasReadSyscallEvents && hasWriteSyscallEvents {
if hasReadSyscallEvents && hasWriteSyscallEvents && canCalculateReadPathTime && canCalculateWritePathTime {
annotatedRecord.BlackBoxDuration = float64(events.writeSyscallEvents[len(events.writeSyscallEvents)-1].GetEndTs()) - float64(events.readSyscallEvents[0].GetStartTs())
} else {
annotatedRecord.BlackBoxDuration = float64(events.egressMessage.TimestampNs()) - float64(events.ingressMessage.TimestampNs())
}
if hasUserCopyEvents && hasTcpInEvents {
if hasUserCopyEvents && hasTcpInEvents && canCalculateReadPathTime {
annotatedRecord.ReadFromSocketBufferDuration = float64(events.userCopyEvents[len(events.userCopyEvents)-1].GetStartTs()) - float64(events.tcpInEvents[0].GetStartTs())
}
if hasTcpInEvents && hasNicInEvents {
if hasTcpInEvents && hasNicInEvents && canCalculateWritePathTime {
annotatedRecord.CopyToSocketBufferDuration = float64(events.tcpInEvents[len(events.tcpInEvents)-1].GetStartTs() - events.nicIngressEvents[0].GetStartTs())
}
annotatedRecord.ReqSyscallEventDetails = KernEventsToEventDetails[analysisCommon.SyscallEventDetail](events.readSyscallEvents)
annotatedRecord.RespSyscallEventDetails = KernEventsToEventDetails[analysisCommon.SyscallEventDetail](events.writeSyscallEvents)
annotatedRecord.ReqNicEventDetails = KernEventsToNicEventDetails(events.nicIngressEvents)
annotatedRecord.RespNicEventDetails = KernEventsToNicEventDetails(events.devOutEvents)
} else {
if hasWriteSyscallEvents {
canCalculateReadPathTime := !connection.IsSsl() || isKernEvtCanMatchSslEvt(events.sslReadSyscallEvents)
canCalculateWritePathTime := !connection.IsSsl() || isKernEvtCanMatchSslEvt(events.sslWriteSyscallEvents)
if hasWriteSyscallEvents && canCalculateWritePathTime {
annotatedRecord.StartTs = findMinTimestamp(events.writeSyscallEvents, true)
} else {
annotatedRecord.StartTs = events.egressMessage.TimestampNs()
}
if hasReadSyscallEvents {
if hasReadSyscallEvents && canCalculateReadPathTime {
annotatedRecord.EndTs = findMaxTimestamp(events.readSyscallEvents, false)
} else {
annotatedRecord.EndTs = events.ingressMessage.TimestampNs()
Expand All @@ -239,12 +244,12 @@ func (s *StatRecorder) ReceiveRecord(r protocol.Record, connection *conn.Connect
}
annotatedRecord.ReqSize = events.egressKernLen
annotatedRecord.RespSize = events.ingressKernLen
if hasReadSyscallEvents && hasWriteSyscallEvents {
if hasReadSyscallEvents && hasWriteSyscallEvents && canCalculateReadPathTime && canCalculateWritePathTime {
annotatedRecord.TotalDuration = float64(annotatedRecord.EndTs) - float64(annotatedRecord.StartTs)
} else {
annotatedRecord.TotalDuration = float64(events.ingressMessage.TimestampNs()) - float64(events.egressMessage.TimestampNs())
}
if hasNicInEvents && hasDevOutEvents {
if hasNicInEvents && hasDevOutEvents && canCalculateReadPathTime && canCalculateWritePathTime {
nicIngressTimestamp := int64(0)
for _, nicIngressEvent := range events.nicIngressEvents {
_nicIngressTimestamp, _, ok := nicIngressEvent.GetMinIfItmestampAttr()
Expand All @@ -271,7 +276,7 @@ func (s *StatRecorder) ReceiveRecord(r protocol.Record, connection *conn.Connect
annotatedRecord.BlackBoxDuration = -1
}
}
if (hasUserCopyEvents || hasReadSyscallEvents) && hasTcpInEvents {
if (hasUserCopyEvents || hasReadSyscallEvents) && hasTcpInEvents && canCalculateReadPathTime {
var readFromEndTime float64
if hasUserCopyEvents {
readFromEndTime = float64(events.userCopyEvents[len(events.userCopyEvents)-1].GetStartTs())
Expand All @@ -280,7 +285,7 @@ func (s *StatRecorder) ReceiveRecord(r protocol.Record, connection *conn.Connect
}
annotatedRecord.ReadFromSocketBufferDuration = readFromEndTime - float64(events.tcpInEvents[0].GetStartTs())
}
if hasTcpInEvents && hasNicInEvents {
if hasTcpInEvents && hasNicInEvents && canCalculateReadPathTime {
annotatedRecord.CopyToSocketBufferDuration = float64(events.tcpInEvents[len(events.tcpInEvents)-1].GetStartTs() - events.nicIngressEvents[0].GetStartTs())
}
annotatedRecord.ReqSyscallEventDetails = KernEventsToEventDetails[analysisCommon.SyscallEventDetail](events.writeSyscallEvents)
Expand Down Expand Up @@ -319,6 +324,17 @@ func (s *StatRecorder) ReceiveRecord(r protocol.Record, connection *conn.Connect
return nil
}

// isKernEvtCanMatchSslEvt reports whether every SSL event in events has a
// non-zero KernLen.
//
// Some syscalls are not nested inside SSL events. When any SSL event has a
// KernLen of zero, the kernel sequence is not valid, so durations that depend
// on kernel events cannot be calculated for that path. An empty slice yields
// true.
func isKernEvtCanMatchSslEvt(events []conn.SslEvent) bool {
	for i := range events {
		if events[i].KernLen == 0 {
			// At least one SSL event has no matching kernel bytes.
			return false
		}
	}
	return true
}

func findMaxTimestamp(events []conn.KernEvent, useStartTs bool) uint64 {
var maxTimestamp uint64 = 0
for _, each := range events {
Expand Down
26 changes: 15 additions & 11 deletions agent/conn/conntrack.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ type Connection4 struct {

ssl bool

tracable bool
tracable bpf.AgentConnTraceStateT
onRoleChanged func()

TempKernEvents []*bpf.AgentKernEvt
Expand Down Expand Up @@ -75,7 +75,7 @@ func NewConnFromEvent(event *bpf.AgentConnEvtT, p *Processor) *Connection4 {
Role: event.ConnInfo.Role,
TgidFd: TgidFd,
Status: Connected,
tracable: true,
tracable: bpf.AgentConnTraceStateTUnset,

MessageFilter: p.messageFilter,
LatencyFilter: p.latencyFilter,
Expand Down Expand Up @@ -330,37 +330,37 @@ func (c *Connection4) OnClose(needClearBpfMap bool) {
monitor.UnregisterMetricExporter(c.StreamEvents)
}

func (c *Connection4) UpdateConnectionTraceable(traceable bool) {
if c.tracable == traceable {
func (c *Connection4) UpdateConnectionTraceable(traceableState bpf.AgentConnTraceStateT) {
if c.tracable == traceableState {
return
}
c.tracable = traceable
c.tracable = traceableState
key, _ := c.extractSockKeys()
sockKeyConnIdMap := bpf.GetMapFromObjs(bpf.Objs, "SockKeyConnIdMap")
c.doUpdateConnIdMapProtocolToUnknwon(key, sockKeyConnIdMap, traceable)
c.doUpdateConnIdMapProtocolToUnknwon(key, sockKeyConnIdMap, traceableState)
// c.doUpdateConnIdMapProtocolToUnknwon(revKey, sockKeyConnIdMap, traceable)

connInfoMap := bpf.GetMapFromObjs(bpf.Objs, "ConnInfoMap")
connInfo := bpf.AgentConnInfoT{}
err := connInfoMap.Lookup(c.TgidFd, &connInfo)
if err == nil {
connInfo.NoTrace = !traceable
connInfo.NoTrace = traceableState
connInfoMap.Update(c.TgidFd, &connInfo, ebpf.UpdateExist)
if common.ConntrackLog.Level >= logrus.DebugLevel {
common.ConntrackLog.Debugf("try to update %s conn_info_map to traceable: %v success!", c.ToString(), traceable)
common.ConntrackLog.Debugf("try to update %s conn_info_map to traceable: %v success!", c.ToString(), traceableState)
}
} else {
if common.ConntrackLog.Level >= logrus.DebugLevel {
common.ConntrackLog.Debugf("try to update %s conn_info_map to traceable: %v, but no entry in map found!", c.ToString(), traceable)
common.ConntrackLog.Debugf("try to update %s conn_info_map to traceable: %v, but no entry in map found!", c.ToString(), traceableState)
}
}
}

func (c *Connection4) doUpdateConnIdMapProtocolToUnknwon(key bpf.AgentSockKey, m *ebpf.Map, traceable bool) {
func (c *Connection4) doUpdateConnIdMapProtocolToUnknwon(key bpf.AgentSockKey, m *ebpf.Map, traceable bpf.AgentConnTraceStateT) {
var connIds bpf.AgentConnIdS_t
err := m.Lookup(&key, &connIds)
if err == nil {
connIds.NoTrace = !traceable
connIds.NoTrace = traceable
m.Update(&key, &connIds, ebpf.UpdateExist)
if common.ConntrackLog.Level >= logrus.DebugLevel {
common.ConntrackLog.Debugf("try to update %s conn_id_map to traceable: %v, success, sock key: %v", c.ToString(), traceable, key)
Expand All @@ -372,6 +372,10 @@ func (c *Connection4) doUpdateConnIdMapProtocolToUnknwon(key bpf.AgentSockKey, m
}
}

// IsTraceble reports whether the connection's current trace state compares
// less than or equal to bpf.AgentConnTraceStateTTraceable.
// NOTE(review): this presumably treats the initial unset state as traceable
// too; confirm against the generated enum ordering.
func (c *Connection4) IsTraceble() bool {
	state := c.tracable
	return state <= bpf.AgentConnTraceStateTTraceable
}

// func (c *Connection4) OnCloseWithoutClearBpfMap() {
// c.OnClose(false)
// }
Expand Down
4 changes: 2 additions & 2 deletions agent/conn/first_packet_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func (p *FirstPacketProcessor) processEvent(event *timedFirstPacketEvent) {
channel := p.channels[int(conn.(*Connection4).TgidFd)%len(p.channels)]
connId := &bpf.AgentConnIdS_t{
TgidFd: conn.(*Connection4).TgidFd,
NoTrace: false,
NoTrace: conn.(*Connection4).tracable,
}
common.BPFEventLog.Debugf("%s First packet event: %+v", conn.(*Connection4).ToString(), event.FirstPacketEvent)
kernEvent := timedFirstPacketEventAsKernEvent(event, connId)
Expand All @@ -90,7 +90,7 @@ func (p *FirstPacketProcessor) extractTgidFdFromSockKey(key *bpf.AgentSockKey) (
sockKeyConnIdMap := bpf.GetMapFromObjs(bpf.Objs, "SockKeyConnIdMap")
var connIds bpf.AgentConnIdS_t
err := sockKeyConnIdMap.Lookup(key, &connIds)
if err == nil && !connIds.NoTrace {
if err == nil && connIds.NoTrace <= bpf.AgentConnTraceStateTTraceable {
return &connIds, nil
}
return nil, err
Expand Down
20 changes: 14 additions & 6 deletions agent/conn/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ func (p *Processor) run() {
// previousProtocol := conn.Protocol
if conn != nil && conn.Status != Closed {
conn.Protocol = event.ConnInfo.Protocol
common.ConntrackLog.Debugf("[protocol-infer][%s] protocol updated: %d", conn.ToString(), conn.Protocol)
} else {
if conn == nil {
missedConn := NewConnFromEvent(event, p)
Expand All @@ -226,6 +227,7 @@ func (p *Processor) run() {
p.connManager.AddConnection4(TgidFd, missedConn)
conn = missedConn
} else {
common.ConntrackLog.Debugf("[protocol-infer][%s] protocol not updated: %d", conn.ToString(), conn.Protocol)
continue
}
}
Expand Down Expand Up @@ -258,7 +260,7 @@ func (p *Processor) run() {
}
}
conn.TempSslEvents = conn.TempSslEvents[0:0]
conn.UpdateConnectionTraceable(true)
conn.UpdateConnectionTraceable(bpf.AgentConnTraceStateTTraceable)
// handle kern events
for _, kernEvent := range conn.TempKernEvents {
if conn.timeBoundCheck(kernEvent.Ts) {
Expand All @@ -275,7 +277,13 @@ func (p *Processor) run() {
if common.ConntrackLog.Level >= logrus.DebugLevel {
common.ConntrackLog.Debugf("%s discarded due to not interested, isProtocolInterested: %v, isSideNotMatched:%v", conn.ToString(), isProtocolInterested, isSideNotMatched(p, conn))
}
conn.UpdateConnectionTraceable(false)
if conn.Protocol == bpf.AgentTrafficProtocolTKProtocolUnknown {
conn.UpdateConnectionTraceable(bpf.AgentConnTraceStateTProtocolUnknown)
} else if !isProtocolInterested {
conn.UpdateConnectionTraceable(bpf.AgentConnTraceStateTProtocolNotMatched)
} else {
conn.UpdateConnectionTraceable(bpf.AgentConnTraceStateTOther)
}
// conn.OnClose(true)
}
}
Expand Down Expand Up @@ -464,7 +472,7 @@ func (p *Processor) processSyscallEvent(event *bpf.SyscallEventData, recordChann
}
return
}
if conn != nil && !conn.tracable {
if conn != nil && !conn.IsTraceble() {
if common.BPFEventLog.Level >= logrus.DebugLevel {
common.BPFEventLog.Debugf("[syscall][no-trace][len=%d][ts=%d]%s | %s", event.SyscallEvent.BufSize, event.SyscallEvent.Ke.Ts, conn.ToString(), string(event.Buf))
}
Expand Down Expand Up @@ -538,7 +546,7 @@ func (p *Processor) processSslEvent(event *bpf.SslData, recordChannel chan Recor
}
return
}
if conn != nil && !conn.tracable {
if conn != nil && !conn.IsTraceble() {
conn.AddSslEvent(event)
if common.BPFEventLog.Level >= logrus.DebugLevel {
common.BPFEventLog.Debugf("[ssl][no-trace][len=%d][ts=%d]%s | %s", event.SslEventHeader.BufSize, event.SslEventHeader.Ke.Ts, conn.ToString(), string(event.Buf))
Expand Down Expand Up @@ -576,12 +584,12 @@ func onRoleChanged(p *Processor, conn *Connection4) {
if common.ConntrackLog.Level >= logrus.DebugLevel {
common.ConntrackLog.Debugf("[onRoleChanged] %s discarded due to not matched by side", conn.ToString())
}
conn.UpdateConnectionTraceable(false)
conn.UpdateConnectionTraceable(bpf.AgentConnTraceStateTOther)
} else {
if common.ConntrackLog.Level >= logrus.DebugLevel {
common.ConntrackLog.Debugf("[onRoleChanged] %s actived due to matched by side", conn.ToString())
}
conn.UpdateConnectionTraceable(true)
conn.UpdateConnectionTraceable(bpf.AgentConnTraceStateTTraceable)
}
}

Expand Down
19 changes: 15 additions & 4 deletions bpf/agent_arm64_bpfel.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 15 additions & 4 deletions bpf/agent_x86_bpfel.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 12 additions & 8 deletions bpf/data_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ static __inline bool should_trace_conn(struct conn_info_t *conn_info) {
// return true;
// }

return conn_info->protocol != kProtocolUnknown && !conn_info->no_trace;
return conn_info->protocol != kProtocolUnknown && conn_info->no_trace <= traceable ;
}

static void __always_inline report_syscall_buf_without_data(void* ctx, uint64_t seq, struct conn_id_s_t *conn_id_s, size_t len, enum step_t step, uint64_t ts, uint32_t ts_delta, enum source_function_t source_fn) {
Expand Down Expand Up @@ -276,7 +276,7 @@ static __always_inline void process_sendfile_with_conn_info(void* ctx, struct se
} else {
step = direct == kEgress ? SYSCALL_OUT : SYSCALL_IN;
}
if (conn_info->protocol != kProtocolUnknown && (!conn_info->no_trace)) {//, bytes_count
if (conn_info->protocol != kProtocolUnknown && ( conn_info->no_trace <= traceable)) {//, bytes_count
report_syscall_buf_without_data(ctx, seq, &conn_id_s, bytes_count, step, args->start_ts, args->end_ts - args->start_ts, kSyscallSendfile);
}
}
Expand All @@ -286,12 +286,12 @@ static __always_inline void process_syscall_data_with_conn_info(void* ctx, struc
bool inferred = false;
if ((conn_info->protocol == kProtocolUnset || conn_info->protocol == kProtocolUnknown) && with_data) {
enum traffic_protocol_t before_infer = conn_info->protocol;
// bpf_printk("[protocol infer]:start, bc:%d", bytes_count);
// bpf_printk("SSL[protocol infer]:start, bc:%d", bytes_count);
// conn_info->protocol = protocol_message.protocol;
struct protocol_message_t protocol_message = infer_protocol(args->buf, bytes_count, conn_info);
if (before_infer != protocol_message.protocol) {
conn_info->protocol = protocol_message.protocol;
// bpf_printk("[protocol infer]: %d, func: %d", conn_info->protocol, args->source_fn);
// bpf_printk("SSL[protocol infer]: %d, func: %d", conn_info->protocol, args->source_fn);

if (conn_info->role == kRoleUnknown && protocol_message.type != kUnknown) {
conn_info->role = ((direct == kEgress) ^ (protocol_message.type == kResponse))
Expand All @@ -313,20 +313,24 @@ static __always_inline void process_syscall_data_with_conn_info(void* ctx, struc
} else {
step = direct == kEgress ? SYSCALL_OUT : SYSCALL_IN;
}

if (conn_info->protocol != kProtocolUnknown && (inferred || !conn_info->no_trace)) {//, bytes_count

if (conn_info->protocol != kProtocolUnknown && (inferred || conn_info->no_trace <= traceable) ||
// The condition below covers the case where the protocol was already inferred in a previous syscall
// but user space has not yet updated conn_info.no_trace to traceable.
// So when conn_info.protocol is not unknown but no_trace is still protocol_unknown, we still trace the data.
(conn_info->protocol != kProtocolUnknown && conn_info->no_trace == protocol_unknown)) {
if (is_ssl) {
uint64_t syscall_seq = (direct == kEgress ? conn_info->write_bytes : conn_info->read_bytes) + 1;
seq = (direct == kEgress ? conn_info->ssl_write_bytes : conn_info->ssl_read_bytes) + 1;
report_ssl_evt(ctx, seq, &conn_id_s, bytes_count, step, args, syscall_len < 0 ? 0 : (syscall_seq - syscall_len), syscall_len < 0 ? 0 : syscall_len);
// bpf_printk("report ssl evt, seq: %lld len: %d", seq, bytes_count);
// bpf_printk("SSLreport ssl evt, seq: %lld len: %d, syscall_len:%d", seq, bytes_count, syscall_len);
} else if (with_data) {
report_syscall_evt(ctx, seq, &conn_id_s, bytes_count, step, args);
} else {
report_syscall_buf_without_data(ctx, seq, &conn_id_s, bytes_count, step, args->start_ts, args->end_ts - args->start_ts, args->source_fn);
}
} else {
// bpf_printk("no trace, bytes_count:%d", bytes_count);
// bpf_printk("SSLno trace, bytes_count:%d,p:%d,infer:%d", bytes_count,conn_info->protocol,inferred);
}
}

Expand Down
Loading

0 comments on commit 6d507da

Please sign in to comment.