Skip to content

Commit

Permalink
增加支持ProxyProtocolV2
Browse files Browse the repository at this point in the history
  • Loading branch information
nnhy committed Nov 3, 2024
1 parent 3701f6c commit 6a0255f
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 8 deletions.
29 changes: 21 additions & 8 deletions NewLife.MQTT/ProxyProtocol/ProxyCodec.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using System.Net;
using NewLife.Data;
using NewLife.Data;
using NewLife.Model;
using NewLife.Net;

Expand All @@ -14,7 +13,7 @@ public class ProxyCodec : Handler
/// <returns></returns>
public override Object? Read(IHandlerContext context, Object message)
{
if (message is IPacket pk)
if (message is IPacket pk && context is NetHandlerContext ctx && ctx.Data is ReceivedEventArgs e)
{
var data = pk.GetSpan();
if (ProxyMessage.FastValidHeader(data))
Expand All @@ -27,11 +26,25 @@ public class ProxyCodec : Handler
{
ext["Proxy"] = msg;

if (context is NetHandlerContext ctx && ctx.Data is ReceivedEventArgs e)
{
// 修改远程地址
e.Remote = msg.GetClient().EndPoint;
}
// 修改远程地址
e.Remote = msg.GetClient().EndPoint;
}

message = pk.Slice(rs);
}
}
else if (ProxyMessageV2.FastValidHeader(data))
{
var msg = new ProxyMessageV2();
var rs = msg.Read(data);
if (rs > 0)
{
if (context is IExtend ext)
{
ext["Proxy"] = msg;

// 修改远程地址
e.Remote = msg.Client!.EndPoint;
}

message = pk.Slice(rs);
Expand Down
171 changes: 171 additions & 0 deletions NewLife.MQTT/ProxyProtocol/ProxyMessageV2.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
using System.Net;
using System.Net.Sockets;
using NewLife.Buffers;
using NewLife.Data;
using NewLife.Net;

namespace NewLife.MQTT.ProxyProtocol;

/// <summary>ProxyProtocol协议消息</summary>
public class ProxyMessageV2
{
#region 属性
/// <summary>版本</summary>
public Byte Version { get; set; }

/// <summary>命令</summary>
public Byte Command { get; set; }

/// <summary>客户端</summary>
public NetUri? Client { get; set; }

/// <summary>代理端</summary>
public NetUri? Proxy { get; set; }
#endregion

#region 核心读写方法
private static readonly Byte[] _Magic = [(Byte)'\r', (Byte)'\n', (Byte)'\r', (Byte)'\n', 0, (Byte)'\r', (Byte)'\n', (Byte)'Q', (Byte)'U', (Byte)'I', (Byte)'T', (Byte)'\n'];
private static readonly Byte[] _NewLine = [(Byte)'\r', (Byte)'\n'];

/// <summary>快速验证协议头</summary>
/// <param name="data"></param>
/// <returns></returns>
public static Boolean FastValidHeader(ReadOnlySpan<Byte> data) => data.StartsWith(_Magic);

/// <summary>解析协议</summary>
/// <param name="data"></param>
/// <returns></returns>
public Int32 Read(ReadOnlySpan<Byte> data)
{
if (!data.StartsWith(_Magic)) return -1;

var reader = new SpanReader(data);
reader.Advance(_Magic.Length);

var flag = reader.ReadByte();
Version = (Byte)(flag >> 4);
Command = (Byte)(flag & 0x0F);

flag = reader.ReadByte();
var len = reader.ReadUInt16();

var family = (flag >> 4) switch
{
0 => AddressFamily.Unspecified,
1 => AddressFamily.InterNetwork,
2 => AddressFamily.InterNetworkV6,
3 => AddressFamily.Unix,
_ => AddressFamily.Unspecified,
};
var protocol = (flag & 0x0F) switch
{
0 => NetType.Unknown,
1 => NetType.Tcp,
2 => NetType.Udp,
_ => NetType.Unknown,
};

switch (family)
{
case AddressFamily.InterNetwork:
{
var src_addr = reader.ReadInt32();
var dst_addr = reader.ReadInt32();
var src_port = reader.ReadUInt16();
var dst_port = reader.ReadUInt16();

Client = new NetUri(protocol, new IPAddress(src_addr), src_port);
Proxy = new NetUri(protocol, new IPAddress(dst_addr), dst_port);
}
break;
case AddressFamily.InterNetworkV6:
{
var src_addr = reader.ReadBytes(16);
var dst_addr = reader.ReadBytes(16);
var src_port = reader.ReadUInt16();
var dst_port = reader.ReadUInt16();

Client = new NetUri(protocol, new IPAddress(src_addr.ToArray()), src_port);
Proxy = new NetUri(protocol, new IPAddress(dst_addr.ToArray()), dst_port);
}
break;
case AddressFamily.Unix:
{
//var src_addr = reader.ReadBytes(16);
//var dst_addr = reader.ReadBytes(16);

throw new NotSupportedException();
}
//break;
}

// 后续TLV数据
len = (UInt16)(_Magic.Length + len - reader.Position);
if (len > 0)
{
var vs = reader.ReadBytes(len);
//todo: 支持解析TLV数据
}

return reader.Position;
}

/// <summary>转为数据包</summary>
/// <returns></returns>
public IPacket ToPacket()
{
var pk = new OwnerPacket(256);
var writer = new SpanWriter(pk.GetSpan());

writer.Write(_Magic);

writer.WriteByte((Version << 4) | Command);

var src = Client;
var dst = Proxy;
var flag = 0;
switch (src!.Address.AddressFamily)
{
case AddressFamily.InterNetwork:
flag |= 0x10;
break;
case AddressFamily.InterNetworkV6:
flag |= 0x20;
break;
case AddressFamily.Unix:
flag |= 0x30;
break;
}
switch (src!.Type)
{
case NetType.Tcp:
flag |= 0x01;
break;
case NetType.Udp:
flag |= 0x02;
break;
}
writer.WriteByte(flag);

switch (src!.Address.AddressFamily)
{
case AddressFamily.InterNetwork:
writer.Write(src!.Address.GetAddressBytes());
writer.Write(dst!.Address.GetAddressBytes());
writer.Write((UInt16)src.Port);
writer.Write((UInt16)dst.Port);
break;
case AddressFamily.InterNetworkV6:
writer.Write(src!.Address.GetAddressBytes());
writer.Write(dst!.Address.GetAddressBytes());
writer.Write((UInt16)src.Port);
writer.Write((UInt16)dst.Port);
break;
case AddressFamily.Unix:
break;
}

return pk;
}
#endregion
}

0 comments on commit 6a0255f

Please sign in to comment.