From 2e074b92c2a148b2571a5caa2200ff94d6dc52ee Mon Sep 17 00:00:00 2001 From: Daniel Date: Sun, 17 Oct 2021 23:37:28 +0100 Subject: [PATCH] Fix crash on map change (#74) --- source/{main.hpp => debug.hpp} | 12 - source/filecheck.cpp | 4 +- source/filecheck.hpp | 2 +- source/main.cpp | 12 +- source/netfilter/client.cpp | 2 +- source/netfilter/clientmanager.cpp | 2 +- source/netfilter/core.cpp | 1356 +++++++++++++++------------- source/netfilter/core.hpp | 2 +- 8 files changed, 747 insertions(+), 645 deletions(-) rename source/{main.hpp => debug.hpp} (65%) diff --git a/source/main.hpp b/source/debug.hpp similarity index 65% rename from source/main.hpp rename to source/debug.hpp index bff289c..00a37e0 100644 --- a/source/main.hpp +++ b/source/debug.hpp @@ -1,10 +1,5 @@ #pragma once -#include - -#include -#include - #if defined DEBUG #include @@ -19,10 +14,3 @@ #define _DebugWarning( ... ) #endif - -class IServer; - -namespace global -{ - extern IServer *server; -} diff --git a/source/filecheck.cpp b/source/filecheck.cpp index 0df128d..929c0cd 100644 --- a/source/filecheck.cpp +++ b/source/filecheck.cpp @@ -1,5 +1,5 @@ #include "filecheck.hpp" -#include "main.hpp" +#include "debug.hpp" #include #include @@ -162,7 +162,7 @@ namespace filecheck LUA->SetField( -2, "EnableFileValidation" ); } - void Deinitialize( GarrysMod::Lua::ILuaBase * ) + void Deinitialize( ) { hook.Destroy( ); } diff --git a/source/filecheck.hpp b/source/filecheck.hpp index e67d200..7e42b95 100644 --- a/source/filecheck.hpp +++ b/source/filecheck.hpp @@ -11,5 +11,5 @@ namespace GarrysMod namespace filecheck { void Initialize( GarrysMod::Lua::ILuaBase *LUA ); - void Deinitialize( GarrysMod::Lua::ILuaBase *LUA ); + void Deinitialize( ); } diff --git a/source/main.cpp b/source/main.cpp index 9471b3c..16bbf3f 100644 --- a/source/main.cpp +++ b/source/main.cpp @@ -1,4 +1,3 @@ -#include "main.hpp" #include "netfilter/core.hpp" #include "filecheck.hpp" @@ -10,8 +9,7 @@ namespace global { - SourceSDK::FactoryLoader engine_loader( "engine" ); - IServer *server = nullptr; + static IServer *server = nullptr; LUA_FUNCTION_STATIC( GetClientCount ) { @@ -27,11 +25,11 @@ namespace global LUA->CreateTable( ); - LUA->PushString( "serversecure 1.5.37" ); + LUA->PushString( "serversecure 1.5.38" ); LUA->SetField( -2, "Version" ); // version num follows LuaJIT style, xxyyzz - LUA->PushNumber( 10537 ); + LUA->PushNumber( 10538 ); LUA->SetField( -2, "VersionNum" ); LUA->PushCFunction( GetClientCount ); @@ -61,8 +59,8 @@ GMOD_MODULE_OPEN( ) GMOD_MODULE_CLOSE( ) { - filecheck::Deinitialize( LUA ); - netfilter::Deinitialize( LUA ); + filecheck::Deinitialize( ); + netfilter::Deinitialize( ); global::Deinitialize( LUA ); return 0; } diff --git a/source/netfilter/client.cpp b/source/netfilter/client.cpp index 9ac719d..2c142c7 100644 --- a/source/netfilter/client.cpp +++ b/source/netfilter/client.cpp @@ -1,6 +1,6 @@ #include "client.hpp" #include "clientmanager.hpp" -#include "main.hpp" +#include "debug.hpp" namespace netfilter { diff --git a/source/netfilter/clientmanager.cpp b/source/netfilter/clientmanager.cpp index 623e528..9dfec85 100644 --- a/source/netfilter/clientmanager.cpp +++ b/source/netfilter/clientmanager.cpp @@ -1,5 +1,5 @@ #include "clientmanager.hpp" -#include "main.hpp" +#include "debug.hpp" namespace netfilter { diff --git a/source/netfilter/core.cpp b/source/netfilter/core.cpp index f01ec5f..4e97632 100644 --- a/source/netfilter/core.cpp +++ b/source/netfilter/core.cpp @@ -1,11 +1,12 @@ #include "core.hpp" #include "clientmanager.hpp" -#include "main.hpp" +#include "debug.hpp" #include "baseserver.h" #include #include #include +#include #include #include @@ -28,6 +29,8 @@ #include #include #include +#include +#include #if defined SYSTEM_WINDOWS @@ -38,9 +41,6 @@ #include #include -#include -#include - typedef int32_t ssize_t; typedef int32_t recvlen_t; @@ -54,9 +54,6 @@ typedef int32_t recvlen_t; #include #include -#include -#include - typedef int32_t SOCKET; typedef size_t recvlen_t; @@ -72,9 +69,6 @@ static const SOCKET INVALID_SOCKET = -1; #include #include -#include -#include - typedef int32_t SOCKET; typedef size_t recvlen_t; @@ -92,687 +86,855 @@ struct netsocket_t namespace netfilter { - struct packet_t + class Core { - packet_t( ) : - address( ), - address_size( sizeof( address ) ) - { } - - sockaddr_in address; - socklen_t address_size; - std::vector buffer; - }; + private: + struct server_tags_t + { + std::string gm; + std::string gmws; + std::string gmc; + std::string loc; + std::string ver; + }; - struct server_tags_t - { - std::string gm; - std::string gmws; - std::string gmc; - std::string loc; - std::string ver; - }; + public: + struct packet_t + { + packet_t( ) : + address( ), + address_size( sizeof( address ) ) + { } - struct reply_info_t - { - std::string game_dir; - std::string game_version; - std::string game_desc; - int32_t max_clients = 0; - int32_t udp_port = 0; - server_tags_t tags; - }; + sockaddr_in address; + socklen_t address_size; + std::vector buffer; + }; - enum class PacketType - { - Invalid = -1, - Good, - Info - }; + Core( const char *game_version ) + { + server = InterfacePointers::Server( ); + if( server == nullptr ) + throw std::runtime_error( "failed to dereference IServer" ); -#if defined SYSTEM_WINDOWS + if( !server_loader.IsValid( ) ) + throw std::runtime_error( "unable to get server factory" ); - static constexpr char operating_system_char = 'w'; + ICvar *icvar = InterfacePointers::Cvar( ); + if( icvar != nullptr ) + { + sv_visiblemaxplayers = icvar->FindVar( "sv_visiblemaxplayers" ); + sv_location = icvar->FindVar( "sv_location" ); + } -#elif defined SYSTEM_POSIX + if( sv_visiblemaxplayers == nullptr ) + ConColorMsg( Color( 255, 255, 0, 255 ), "[ServerSecure] Failed to get \"sv_visiblemaxplayers\" convar!\n" ); - static constexpr char operating_system_char = 'l'; + if( sv_location == nullptr ) + ConColorMsg( Color( 255, 255, 0, 255 ), "[ServerSecure] Failed to get \"sv_location\" convar!\n" ); -#elif defined SYSTEM_MACOSX + gamedll = InterfacePointers::ServerGameDLL( ); + if( gamedll == nullptr ) + throw std::runtime_error( "failed to load required IServerGameDLL interface" ); - static constexpr char operating_system_char = 'm'; + engine_server = InterfacePointers::VEngineServer( ); + if( engine_server == nullptr ) + throw std::runtime_error( "failed to load required IVEngineServer interface" ); -#endif + filesystem = InterfacePointers::FileSystem( ); + if( filesystem == nullptr ) + throw std::runtime_error( "failed to initialize IFileSystem" ); - static CSteamGameServerAPIContext gameserver_context; - static bool gameserver_context_initialized = false; + const FunctionPointers::GMOD_GetNetSocket_t GetNetSocket = FunctionPointers::GMOD_GetNetSocket( ); + if( GetNetSocket != nullptr ) + { + const netsocket_t *net_socket = GetNetSocket( 1 ); + if( net_socket != nullptr ) + game_socket = net_socket->hUDP; + } - static SourceSDK::FactoryLoader icvar_loader( "vstdlib" ); - static ConVar *sv_visiblemaxplayers = nullptr; - static ConVar *sv_location = nullptr; + if( game_socket == INVALID_SOCKET ) + throw std::runtime_error( "got an invalid server socket" ); - static SourceSDK::ModuleLoader dedicated_loader( "dedicated" ); - static SourceSDK::FactoryLoader server_loader( "server" ); + if( !recvfrom_hook.Enable( ) ) + throw std::runtime_error( "failed to detour recvfrom" ); - static ssize_t SERVERSECURE_CALLING_CONVENTION recvfrom_detour( - SOCKET s, - void *buf, - recvlen_t buflen, - int32_t flags, - sockaddr *from, - socklen_t *fromlen - ); - typedef decltype( recvfrom_detour ) *recvfrom_t; + threaded_socket_execute = true; + threaded_socket_handle = CreateSimpleThread( PacketReceiverThread, this ); + if( threaded_socket_handle == nullptr ) + throw std::runtime_error( "unable to create thread" ); -#ifdef PLATFORM_WINDOWS + BuildStaticReplyInfo( game_version ); + } - static Detouring::Hook recvfrom_hook( "ws2_32", "recvfrom", reinterpret_cast( recvfrom_detour ) ); + ~Core( ) + { + if( threaded_socket_handle != nullptr ) + { + threaded_socket_execute = false; + ThreadJoin( threaded_socket_handle ); + ReleaseThreadHandle( threaded_socket_handle ); + threaded_socket_handle = nullptr; + } -#else + recvfrom_hook.Disable( ); + } - static Detouring::Hook recvfrom_hook( "recvfrom", reinterpret_cast( recvfrom_detour ) ); + void BuildStaticReplyInfo( const char *game_version ) + { + reply_info.game_desc = gamedll->GetGameDescription( ); -#endif + { + reply_info.game_dir.resize( 256 ); + engine_server->GetGameDir( &reply_info.game_dir[0], static_cast( reply_info.game_dir.size( ) ) ); + reply_info.game_dir.resize( std::strlen( reply_info.game_dir.c_str( ) ) ); - static SOCKET game_socket = INVALID_SOCKET; + size_t pos = reply_info.game_dir.find_last_of( "\\/" ); + if( pos != reply_info.game_dir.npos ) + reply_info.game_dir.erase( 0, pos + 1 ); + } - static bool packet_validation_enabled = false; + reply_info.max_clients = server->GetMaxClients( ); - static bool firewall_whitelist_enabled = false; - static std::unordered_set firewall_whitelist; + reply_info.udp_port = server->GetUDPPort( ); - static bool firewall_blacklist_enabled = false; - static std::unordered_set firewall_blacklist; + { + const IGamemodeSystem::Information &gamemode = + static_cast( filesystem )->Gamemodes( )->Active( ); + + if( !gamemode.name.empty( ) ) + reply_info.tags.gm = gamemode.name; + else + reply_info.tags.gm.clear( ); + + if( gamemode.workshopid != 0 ) + reply_info.tags.gmws = std::to_string( gamemode.workshopid ); + else + reply_info.tags.gmws.clear( ); + + if( !gamemode.category.empty( ) ) + reply_info.tags.gmc = gamemode.category; + else + reply_info.tags.gmc.clear( ); + + if( game_version != nullptr ) + reply_info.tags.ver = game_version; + } - static constexpr size_t threaded_socket_max_buffer = 8192; - static constexpr size_t threaded_socket_max_queue = 1000; - static std::atomic_bool threaded_socket_execute( true ); - static ThreadHandle_t threaded_socket_handle = nullptr; - static std::queue threaded_socket_queue; - static CThreadFastMutex threaded_socket_mutex; + { + FileHandle_t file = filesystem->Open( "steam.inf", "r", "GAME" ); + if( file == nullptr ) + { + reply_info.game_version = default_game_version; + _DebugWarning( "[ServerSecure] Error opening steam.inf\n" ); + return; + } - static constexpr char default_game_version[] = "2020.10.14"; - static constexpr uint8_t default_proto_version = 17; - static bool info_cache_enabled = false; - static reply_info_t reply_info; - static char info_cache_buffer[1024] = { 0 }; - static bf_write info_cache_packet( info_cache_buffer, sizeof( info_cache_buffer ) ); - static uint32_t info_cache_last_update = 0; - static uint32_t info_cache_time = 5; + char buff[256] = { 0 }; + bool failed = filesystem->ReadLine( buff, sizeof( buff ), file ) == nullptr; + filesystem->Close( file ); + if( failed ) + { + reply_info.game_version = default_game_version; + _DebugWarning( "[ServerSecure] Failed reading steam.inf\n" ); + return; + } - static ClientManager client_manager; + reply_info.game_version = &buff[13]; - static constexpr size_t packet_sampling_max_queue = 50; - static bool packet_sampling_enabled = false; - static std::queue packet_sampling_queue; - static CThreadFastMutex packet_sampling_mutex; + size_t pos = reply_info.game_version.find_first_of( "\r\n" ); + if( pos != reply_info.game_version.npos ) + reply_info.game_version.erase( pos ); + } + } - static IServerGameDLL *gamedll = nullptr; - static IVEngineServer *engine_server = nullptr; - static IFileSystem *filesystem = nullptr; + static std::string ConcatenateTags( const server_tags_t &tags ) + { + std::string strtags; - // max size needed to contain a steam authentication key (both server and client) - static constexpr size_t STEAM_KEYSIZE = 2048; + if( !tags.gm.empty( ) ) + { + strtags += "gm:"; + strtags += tags.gm; + } - static constexpr int32_t PROTOCOL_AUTHCERTIFICATE = 0x01; // Connection from client is using a WON authenticated certificate - static constexpr int32_t PROTOCOL_HASHEDCDKEY = 0x02; // Connection from client is using hashed CD key because WON comm. channel was unreachable - static constexpr int32_t PROTOCOL_STEAM = 0x03; // Steam certificates - static constexpr int32_t PROTOCOL_LASTVALID = 0x03; // Last valid protocol + if( !tags.gmws.empty( ) ) + { + strtags += strtags.empty( ) ? "gmws:" : " gmws:"; + strtags += tags.gmws; + } - static constexpr int32_t MAX_RANDOM_RANGE = 0x7FFFFFFFUL; + if( !tags.gmc.empty( ) ) + { + strtags += strtags.empty( ) ? "gmc:" : " gmc:"; + strtags += tags.gmc; + } - inline const char *IPToString( const in_addr &addr ) - { - static char buffer[16] = { }; - const char *str = - inet_ntop( AF_INET, const_cast( &addr ), buffer, sizeof( buffer ) ); - if( str == nullptr ) - return "unknown"; + if( !tags.loc.empty( ) ) + { + strtags += strtags.empty( ) ? "loc:" : " loc:"; + strtags += tags.loc; + } - return str; - } + if( !tags.ver.empty( ) ) + { + strtags += strtags.empty( ) ? "ver:" : " ver:"; + strtags += tags.ver; + } - static void BuildStaticReplyInfo( const char *game_version ) - { - reply_info.game_desc = gamedll->GetGameDescription( ); + return strtags; + } + void BuildReplyInfo( ) { - reply_info.game_dir.resize( 256 ); - engine_server->GetGameDir( &reply_info.game_dir[0], static_cast( reply_info.game_dir.size( ) ) ); - reply_info.game_dir.resize( std::strlen( reply_info.game_dir.c_str( ) ) ); + const char *server_name = server->GetName( ); - size_t pos = reply_info.game_dir.find_last_of( "\\/" ); - if( pos != reply_info.game_dir.npos ) - reply_info.game_dir.erase( 0, pos + 1 ); - } + const char *map_name = server->GetMapName( ); - reply_info.max_clients = global::server->GetMaxClients( ); + const char *game_dir = reply_info.game_dir.c_str( ); - reply_info.udp_port = global::server->GetUDPPort( ); + const char *game_desc = reply_info.game_desc.c_str( ); - { - const IGamemodeSystem::Information &gamemode = - static_cast( filesystem )->Gamemodes( )->Active( ); + const int32_t appid = engine_server->GetAppID( ); - if( !gamemode.name.empty( ) ) - reply_info.tags.gm = gamemode.name; - else - reply_info.tags.gm.clear( ); + const int32_t num_clients = server->GetNumClients( ); - if( gamemode.workshopid != 0 ) - reply_info.tags.gmws = std::to_string( gamemode.workshopid ); - else - reply_info.tags.gmws.clear( ); + int32_t max_players = + sv_visiblemaxplayers != nullptr ? sv_visiblemaxplayers->GetInt( ) : -1; + if( max_players <= 0 || max_players > reply_info.max_clients ) + max_players = reply_info.max_clients; - if( !gamemode.category.empty( ) ) - reply_info.tags.gmc = gamemode.category; - else - reply_info.tags.gmc.clear( ); + const int32_t num_fake_clients = server->GetNumFakeClients( ); - if( game_version != nullptr ) - reply_info.tags.ver = game_version; - } + const bool has_password = server->GetPassword( ) != nullptr; - { - FileHandle_t file = filesystem->Open( "steam.inf", "r", "GAME" ); - if( file == nullptr ) - { - reply_info.game_version = default_game_version; - _DebugWarning( "[ServerSecure] Error opening steam.inf\n" ); - return; - } + if( gameserver == nullptr ) + gameserver = SteamGameServer( ); - char buff[256] = { 0 }; - bool failed = filesystem->ReadLine( buff, sizeof( buff ), file ) == nullptr; - filesystem->Close( file ); - if( failed ) - { - reply_info.game_version = default_game_version; - _DebugWarning( "[ServerSecure] Failed reading steam.inf\n" ); - return; - } + bool vac_secure = false; + if( gameserver != nullptr ) + vac_secure = gameserver->BSecure( ); + + const char *game_version = reply_info.game_version.c_str( ); - reply_info.game_version = &buff[13]; + const int32_t udp_port = reply_info.udp_port; - size_t pos = reply_info.game_version.find_first_of( "\r\n" ); - if( pos != reply_info.game_version.npos ) - reply_info.game_version.erase( pos ); + const CSteamID *sid = engine_server->GetGameServerSteamID( ); + const uint64_t steamid = sid != nullptr ? sid->ConvertToUint64( ) : 0; + + if( sv_location != nullptr ) + reply_info.tags.loc = sv_location->GetString( ); + else + reply_info.tags.loc.clear( ); + + const std::string tags = ConcatenateTags( reply_info.tags ); + const bool has_tags = !tags.empty( ); + + info_cache_packet.Reset( ); + + info_cache_packet.WriteLong( -1 ); // connectionless packet header + info_cache_packet.WriteByte( 'I' ); // packet type is always 'I' + info_cache_packet.WriteByte( default_proto_version ); + info_cache_packet.WriteString( server_name ); + info_cache_packet.WriteString( map_name ); + info_cache_packet.WriteString( game_dir ); + info_cache_packet.WriteString( game_desc ); + info_cache_packet.WriteShort( appid ); + info_cache_packet.WriteByte( num_clients ); + info_cache_packet.WriteByte( max_players ); + info_cache_packet.WriteByte( num_fake_clients ); + info_cache_packet.WriteByte( 'd' ); // dedicated server identifier + info_cache_packet.WriteByte( operating_system_char ); + info_cache_packet.WriteByte( has_password ? 1 : 0 ); + info_cache_packet.WriteByte( vac_secure ); + info_cache_packet.WriteString( game_version ); + // 0x80 - port number is present + // 0x10 - server steamid is present + // 0x20 - tags are present + // 0x01 - game long appid is present + info_cache_packet.WriteByte( 0x80 | 0x10 | ( has_tags ? 0x20 : 0x00 ) | 0x01 ); + info_cache_packet.WriteShort( udp_port ); + info_cache_packet.WriteLongLong( steamid ); + if( has_tags ) + info_cache_packet.WriteString( tags.c_str( ) ); + info_cache_packet.WriteLongLong( appid ); } - } - static std::string ConcatenateTags( const server_tags_t &tags ) - { - std::string strtags; + void SetFirewallWhitelistState( const bool enabled ) + { + firewall_whitelist_enabled = enabled; + } - if( !tags.gm.empty( ) ) + // Whitelisted IPs bytes need to be in network order (big endian) + void AddWhitelistIP( const uint32_t address ) { - strtags += "gm:"; - strtags += tags.gm; + firewall_whitelist.insert( address ); } - if( !tags.gmws.empty( ) ) + void RemoveWhitelistIP( const uint32_t address ) { - strtags += strtags.empty( ) ? "gmws:" : " gmws:"; - strtags += tags.gmws; + firewall_whitelist.erase( address ); } - if( !tags.gmc.empty( ) ) + void ResetWhitelist( ) { - strtags += strtags.empty( ) ? "gmc:" : " gmc:"; - strtags += tags.gmc; + std::unordered_set( ).swap( firewall_whitelist ); } - if( !tags.loc.empty( ) ) + void SetFirewallBlacklistState( const bool enabled ) { - strtags += strtags.empty( ) ? "loc:" : " loc:"; - strtags += tags.loc; + firewall_blacklist_enabled = enabled; } - if( !tags.ver.empty( ) ) + // Blacklisted IPs bytes need to be in network order (big endian) + void AddBlacklistIP( const uint32_t address ) { - strtags += strtags.empty( ) ? "ver:" : " ver:"; - strtags += tags.ver; + firewall_blacklist.insert( address ); } - return strtags; - } + void RemoveBlacklistIP( const uint32_t address ) + { + firewall_blacklist.erase( address ); + } - static void BuildReplyInfo( ) - { - const char *server_name = global::server->GetName( ); - - const char *map_name = global::server->GetMapName( ); + void ResetBlacklist( ) + { + std::unordered_set( ).swap( firewall_blacklist ); + } - const char *game_dir = reply_info.game_dir.c_str( ); + void SetPacketValidationState( const bool enabled ) + { + packet_validation_enabled = enabled; + } - const char *game_desc = reply_info.game_desc.c_str( ); + void SetInfoCacheState( const bool enabled ) + { + info_cache_enabled = enabled; + } - const int32_t appid = engine_server->GetAppID( ); + void SetInfoCacheTime( const uint32_t time ) + { + info_cache_time = time; + } - const int32_t num_clients = global::server->GetNumClients( ); + bool PopPacketFromSamplingQueue( packet_t &p ) + { + AUTO_LOCK( packet_sampling_mutex ); - int32_t max_players = - sv_visiblemaxplayers != nullptr ? sv_visiblemaxplayers->GetInt( ) : -1; - if( max_players <= 0 || max_players > reply_info.max_clients ) - max_players = reply_info.max_clients; + if( packet_sampling_queue.empty( ) ) + return false; - const int32_t num_fake_clients = global::server->GetNumFakeClients( ); + p = std::move( packet_sampling_queue.front( ) ); + packet_sampling_queue.pop( ); + return true; + } - const bool has_password = global::server->GetPassword( ) != nullptr; + void SetPacketSamplingState( bool enabled ) + { + packet_sampling_enabled = enabled; - if( !gameserver_context_initialized ) - gameserver_context_initialized = gameserver_context.Init( ); + if( !enabled ) + { + AUTO_LOCK( packet_sampling_mutex ); + std::queue( ).swap( packet_sampling_queue ); + } + } - bool vac_secure = false; - if( gameserver_context_initialized ) + ClientManager &GetClientManager( ) { - ISteamGameServer *steamGS = gameserver_context.SteamGameServer( ); - if( steamGS != nullptr ) - vac_secure = steamGS->BSecure( ); + return client_manager; } - const char *game_version = reply_info.game_version.c_str( ); - - const int32_t udp_port = reply_info.udp_port; - - const CSteamID *sid = engine_server->GetGameServerSteamID( ); - const uint64_t steamid = sid != nullptr ? sid->ConvertToUint64( ) : 0; - - if( sv_location != nullptr ) - reply_info.tags.loc = sv_location->GetString( ); - else - reply_info.tags.loc.clear( ); - - const std::string tags = ConcatenateTags( reply_info.tags ); - const bool has_tags = !tags.empty( ); - - info_cache_packet.Reset( ); - - info_cache_packet.WriteLong( -1 ); // connectionless packet header - info_cache_packet.WriteByte( 'I' ); // packet type is always 'I' - info_cache_packet.WriteByte( default_proto_version ); - info_cache_packet.WriteString( server_name ); - info_cache_packet.WriteString( map_name ); - info_cache_packet.WriteString( game_dir ); - info_cache_packet.WriteString( game_desc ); - info_cache_packet.WriteShort( appid ); - info_cache_packet.WriteByte( num_clients ); - info_cache_packet.WriteByte( max_players ); - info_cache_packet.WriteByte( num_fake_clients ); - info_cache_packet.WriteByte( 'd' ); // dedicated server identifier - info_cache_packet.WriteByte( operating_system_char ); - info_cache_packet.WriteByte( has_password ? 1 : 0 ); - info_cache_packet.WriteByte( vac_secure ); - info_cache_packet.WriteString( game_version ); - // 0x80 - port number is present - // 0x10 - server steamid is present - // 0x20 - tags are present - // 0x01 - game long appid is present - info_cache_packet.WriteByte( 0x80 | 0x10 | ( has_tags ? 0x20 : 0x00 ) | 0x01 ); - info_cache_packet.WriteShort( udp_port ); - info_cache_packet.WriteLongLong( steamid ); - if( has_tags ) - info_cache_packet.WriteString( tags.c_str( ) ); - info_cache_packet.WriteLongLong( appid ); - } + static std::unique_ptr Singleton; - inline PacketType SendInfoCache( const sockaddr_in &from, uint32_t time ) - { - if( time - info_cache_last_update >= info_cache_time ) + private: + struct reply_info_t { - BuildReplyInfo( ); - info_cache_last_update = time; - } - - sendto( - game_socket, - reinterpret_cast( info_cache_packet.GetData( ) ), - info_cache_packet.GetNumBytesWritten( ), - 0, - reinterpret_cast( &from ), - sizeof( from ) + std::string game_dir; + std::string game_version; + std::string game_desc; + int32_t max_clients = 0; + int32_t udp_port = 0; + server_tags_t tags; + }; + + enum class PacketType + { + Invalid = -1, + Good, + Info + }; + + typedef ssize_t ( SERVERSECURE_CALLING_CONVENTION *recvfrom_t )( + SOCKET s, + void *buf, + recvlen_t buflen, + int32_t flags, + sockaddr *from, + socklen_t *fromlen ); - _DebugWarning( "[ServerSecure] Handled %s info request using cache\n", IPToString( from.sin_addr ) ); +#if defined SYSTEM_WINDOWS - return PacketType::Invalid; // we've handled it - } + static constexpr char operating_system_char = 'w'; - inline PacketType HandleInfoQuery( const sockaddr_in &from ) - { - const uint32_t time = static_cast( Plat_FloatTime( ) ); - if( !client_manager.CheckIPRate( from.sin_addr.s_addr, time ) ) - { - _DebugWarning( "[ServerSecure] Client %s hit rate limit\n", IPToString( from.sin_addr ) ); - return PacketType::Invalid; - } +#elif defined SYSTEM_POSIX - if( info_cache_enabled ) - return SendInfoCache( from, time ); + static constexpr char operating_system_char = 'l'; - return PacketType::Good; - } +#elif defined SYSTEM_MACOSX - static PacketType ClassifyPacket( const uint8_t *data, int32_t len, const sockaddr_in &from ) - { - if( len == 0 ) + static constexpr char operating_system_char = 'm'; + +#endif + + static constexpr size_t threaded_socket_max_buffer = 8192; + static constexpr size_t threaded_socket_max_queue = 1000; + + static constexpr char default_game_version[11] = "2020.10.14"; + static constexpr uint8_t default_proto_version = 17; + + static constexpr size_t packet_sampling_max_queue = 50; + + // max size needed to contain a steam authentication key (both server and client) + static constexpr size_t STEAM_KEYSIZE = 2048; + + static constexpr int32_t PROTOCOL_AUTHCERTIFICATE = 0x01; // Connection from client is using a WON authenticated certificate + static constexpr int32_t PROTOCOL_HASHEDCDKEY = 0x02; // Connection from client is using hashed CD key because WON comm. channel was unreachable + static constexpr int32_t PROTOCOL_STEAM = 0x03; // Steam certificates + static constexpr int32_t PROTOCOL_LASTVALID = 0x03; // Last valid protocol + + static constexpr int32_t MAX_RANDOM_RANGE = 0x7FFFFFFFUL; + + IServer *server = nullptr; + + ISteamGameServer *gameserver = nullptr; + + SourceSDK::FactoryLoader icvar_loader = SourceSDK::FactoryLoader( "vstdlib" ); + ConVar *sv_visiblemaxplayers = nullptr; + ConVar *sv_location = nullptr; + + SourceSDK::ModuleLoader dedicated_loader = SourceSDK::ModuleLoader( "dedicated" ); + SourceSDK::FactoryLoader server_loader = SourceSDK::FactoryLoader( "server" ); + + #ifdef PLATFORM_WINDOWS + + Detouring::Hook recvfrom_hook = Detouring::Hook( "ws2_32", "recvfrom", reinterpret_cast( recvfrom_detour ) ); + + #else + + Detouring::Hook recvfrom_hook = Detouring::Hook( "recvfrom", reinterpret_cast( recvfrom_detour ) ); + + #endif + + SOCKET game_socket = INVALID_SOCKET; + + bool packet_validation_enabled = false; + + bool firewall_whitelist_enabled = false; + std::unordered_set firewall_whitelist; + + bool firewall_blacklist_enabled = false; + std::unordered_set firewall_blacklist; + + bool threaded_socket_execute = true; + ThreadHandle_t threaded_socket_handle = nullptr; + std::queue threaded_socket_queue; + CThreadFastMutex threaded_socket_mutex; + + bool info_cache_enabled = false; + reply_info_t reply_info; + char info_cache_buffer[1024] = { 0 }; + bf_write info_cache_packet = bf_write( info_cache_buffer, sizeof( info_cache_buffer ) ); + uint32_t info_cache_last_update = 0; + uint32_t info_cache_time = 5; + + ClientManager client_manager; + + bool packet_sampling_enabled = false; + std::queue packet_sampling_queue; + CThreadFastMutex packet_sampling_mutex; + + IServerGameDLL *gamedll = nullptr; + IVEngineServer *engine_server = nullptr; + IFileSystem *filesystem = nullptr; + + inline const char *IPToString( const in_addr &addr ) { - _DebugWarning( - "[ServerSecure] Bad OOB! len: %d from %s\n", - len, - IPToString( from.sin_addr ) - ); - return PacketType::Invalid; - } + static char buffer[16] = { }; + const char *str = + inet_ntop( AF_INET, const_cast( &addr ), buffer, sizeof( buffer ) ); + if( str == nullptr ) + return "unknown"; - if( len < 5 ) - return PacketType::Good; + return str; + } - bf_read packet( data, len ); - const int32_t channel = static_cast( packet.ReadLong( ) ); - if( channel == -2 ) + PacketType SendInfoCache( const sockaddr_in &from, uint32_t time ) { - _DebugWarning( - "[ServerSecure] Bad OOB! len: %d, channel: 0x%X from %s\n", - len, - channel, - IPToString( from.sin_addr ) + if( time - info_cache_last_update >= info_cache_time ) + { + BuildReplyInfo( ); + info_cache_last_update = time; + } + + sendto( + game_socket, + reinterpret_cast( info_cache_packet.GetData( ) ), + info_cache_packet.GetNumBytesWritten( ), + 0, + reinterpret_cast( &from ), + sizeof( from ) ); - return PacketType::Invalid; + + _DebugWarning( "[ServerSecure] Handled %s info request using cache\n", IPToString( from.sin_addr ) ); + + return PacketType::Invalid; // we've handled it } - if( channel != -1 ) + PacketType HandleInfoQuery( const sockaddr_in &from ) + { + const uint32_t time = static_cast( Plat_FloatTime( ) ); + if( !client_manager.CheckIPRate( from.sin_addr.s_addr, time ) ) + { + _DebugWarning( "[ServerSecure] Client %s hit rate limit\n", IPToString( from.sin_addr ) ); + return PacketType::Invalid; + } + + if( info_cache_enabled ) + return SendInfoCache( from, time ); + return PacketType::Good; + } - const uint8_t type = static_cast( packet.ReadByte( ) ); - if( packet_validation_enabled ) + PacketType ClassifyPacket( const uint8_t *data, int32_t len, const sockaddr_in &from ) { - switch( type ) + if( len == 0 ) { - case 'W': // server challenge request - case 's': // master server challenge - if( len > 100 ) - { - _DebugWarning( - "[ServerSecure] Bad OOB! len: %d, channel: 0x%X, type: %c from %s\n", - len, - channel, - type, - IPToString( from.sin_addr ) - ); - return PacketType::Invalid; - } - - if( len >= 18 && strncmp( reinterpret_cast( data + 5 ), "statusResponse", 14 ) == 0 ) - { - _DebugWarning( - "[ServerSecure] Bad OOB! len: %d, channel: 0x%X, type: %c from %s\n", - len, - channel, - type, - IPToString( from.sin_addr ) - ); - return PacketType::Invalid; - } + _DebugWarning( + "[ServerSecure] Bad OOB! len: %d from %s\n", + len, + IPToString( from.sin_addr ) + ); + return PacketType::Invalid; + } + if( len < 5 ) return PacketType::Good; - case 'T': // server info request (A2S_INFO) - if( ( len != 25 && len != 1200 ) || strncmp( reinterpret_cast( data + 5 ), "Source Engine Query", 19 ) != 0 ) - { - _DebugWarning( - "[ServerSecure] Bad OOB! len: %d, channel: 0x%X, type: %c from %s\n", - len, - channel, - type, - IPToString( from.sin_addr ) - ); - return PacketType::Invalid; - } + bf_read packet( data, len ); + const int32_t channel = static_cast( packet.ReadLong( ) ); + if( channel == -2 ) + { + _DebugWarning( + "[ServerSecure] Bad OOB! len: %d, channel: 0x%X from %s\n", + len, + channel, + IPToString( from.sin_addr ) + ); + return PacketType::Invalid; + } - return PacketType::Info; + if( channel != -1 ) + return PacketType::Good; - case 'U': // player info request (A2S_PLAYER) - case 'V': // rules request (A2S_RULES) - if( len != 9 && len != 1200 ) + const uint8_t type = static_cast( packet.ReadByte( ) ); + if( packet_validation_enabled ) + { + switch( type ) { - _DebugWarning( - "[ServerSecure] Bad OOB! len: %d, channel: 0x%X, type: %c from %s\n", + case 'W': // server challenge request + case 's': // master server challenge + if( len > 100 ) + { + _DebugWarning( + "[ServerSecure] Bad OOB! len: %d, channel: 0x%X, type: %c from %s\n", + len, + channel, + type, + IPToString( from.sin_addr ) + ); + return PacketType::Invalid; + } + + if( len >= 18 && strncmp( reinterpret_cast( data + 5 ), "statusResponse", 14 ) == 0 ) + { + _DebugWarning( + "[ServerSecure] Bad OOB! len: %d, channel: 0x%X, type: %c from %s\n", + len, + channel, + type, + IPToString( from.sin_addr ) + ); + return PacketType::Invalid; + } + + return PacketType::Good; + + case 'T': // server info request (A2S_INFO) + if( ( len != 25 && len != 1200 ) || strncmp( reinterpret_cast( data + 5 ), "Source Engine Query", 19 ) != 0 ) + { + _DebugWarning( + "[ServerSecure] Bad OOB! len: %d, channel: 0x%X, type: %c from %s\n", + len, + channel, + type, + IPToString( from.sin_addr ) + ); + return PacketType::Invalid; + } + + return PacketType::Info; + + case 'U': // player info request (A2S_PLAYER) + case 'V': // rules request (A2S_RULES) + if( len != 9 && len != 1200 ) + { + _DebugWarning( + "[ServerSecure] Bad OOB! len: %d, channel: 0x%X, type: %c from %s\n", + len, + channel, + type, + IPToString( from.sin_addr ) + ); + return PacketType::Invalid; + } + + return PacketType::Good; + + case 'q': // connection handshake init + case 'k': // steam auth packet + _DebugMsg( + "[ServerSecure] Good OOB! len: %d, channel: 0x%X, type: %c from %s\n", len, channel, type, IPToString( from.sin_addr ) ); - return PacketType::Invalid; + return PacketType::Good; } - return PacketType::Good; - - case 'q': // connection handshake init - case 'k': // steam auth packet - _DebugMsg( - "[ServerSecure] Good OOB! len: %d, channel: 0x%X, type: %c from %s\n", + _DebugWarning( + "[ServerSecure] Bad OOB! len: %d, channel: 0x%X, type: %c from %s\n", len, channel, type, IPToString( from.sin_addr ) ); - return PacketType::Good; + return PacketType::Invalid; } - _DebugWarning( - "[ServerSecure] Bad OOB! len: %d, channel: 0x%X, type: %c from %s\n", - len, - channel, - type, - IPToString( from.sin_addr ) - ); - return PacketType::Invalid; + return type == 'T' ? PacketType::Info : PacketType::Good; } - return type == 'T' ? PacketType::Info : PacketType::Good; - } + bool IsAddressAllowed( const sockaddr_in &addr ) + { + return + ( + !firewall_whitelist_enabled || + firewall_whitelist.find( addr.sin_addr.s_addr ) != firewall_whitelist.end( ) + ) && + ( + !firewall_blacklist_enabled || + firewall_blacklist.find( addr.sin_addr.s_addr ) == firewall_blacklist.end( ) + ); + } - inline bool IsAddressAllowed( const sockaddr_in &addr ) - { - return - ( - !firewall_whitelist_enabled || - firewall_whitelist.find( addr.sin_addr.s_addr ) != firewall_whitelist.end( ) - ) && - ( - !firewall_blacklist_enabled || - firewall_blacklist.find( addr.sin_addr.s_addr ) == firewall_blacklist.end( ) - ); - } + int32_t HandleNetError( int32_t value ) + { + if( value == -1 ) - inline int32_t HandleNetError( int32_t value ) - { - if( value == -1 ) + #if defined SYSTEM_WINDOWS -#if defined SYSTEM_WINDOWS + WSASetLastError( WSAEWOULDBLOCK ); - WSASetLastError( WSAEWOULDBLOCK ); + #elif defined SYSTEM_POSIX -#elif defined SYSTEM_POSIX + errno = EWOULDBLOCK; - errno = EWOULDBLOCK; + #endif -#endif + return value; + } - return value; - } + bool IsPacketQueueFull( ) + { + AUTO_LOCK( threaded_socket_mutex ); + return threaded_socket_queue.size( ) >= threaded_socket_max_queue; + } - inline bool IsPacketQueueFull( ) - { - AUTO_LOCK( threaded_socket_mutex ); - return threaded_socket_queue.size( ) >= threaded_socket_max_queue; - } + bool PopPacketFromQueue( packet_t &p ) + { + AUTO_LOCK( threaded_socket_mutex ); - inline bool PopPacketFromQueue( packet_t &p ) - { - AUTO_LOCK( threaded_socket_mutex ); + if( threaded_socket_queue.empty( ) ) + return false; - if( threaded_socket_queue.empty( ) ) - return false; + p = std::move( threaded_socket_queue.front( ) ); + threaded_socket_queue.pop( ); + return true; + } - p = std::move( threaded_socket_queue.front( ) ); - threaded_socket_queue.pop( ); - return true; - } + void PushPacketToQueue( packet_t &&p ) + { + AUTO_LOCK( threaded_socket_mutex ); + threaded_socket_queue.emplace( std::move( p ) ); + } - inline void PushPacketToQueue( packet_t &&p ) - { - AUTO_LOCK( threaded_socket_mutex ); - threaded_socket_queue.emplace( std::move( p ) ); - } + void PushPacketToSamplingQueue( packet_t &&p ) + { + AUTO_LOCK( packet_sampling_mutex ); - inline void PushPacketToSamplingQueue( packet_t &&p ) - { - AUTO_LOCK( packet_sampling_mutex ); + if( packet_sampling_queue.size( ) >= packet_sampling_max_queue ) + packet_sampling_queue.pop( ); - if( packet_sampling_queue.size( ) >= packet_sampling_max_queue ) - packet_sampling_queue.pop( ); + packet_sampling_queue.emplace( std::move( p ) ); + } - packet_sampling_queue.emplace( std::move( p ) ); - } + ssize_t ReceiveAndAnalyzePacket( + SOCKET s, + void *buf, + recvlen_t buflen, + int32_t flags, + sockaddr *from, + socklen_t *fromlen + ) + { + auto trampoline = recvfrom_hook.GetTrampoline( ); + if( trampoline == nullptr ) + return -1; - inline bool PopPacketFromSamplingQueue( packet_t &p ) - { - AUTO_LOCK( packet_sampling_mutex ); + const ssize_t len = trampoline( s, buf, buflen, flags, from, fromlen ); + _DebugWarning( "[ServerSecure] Called recvfrom on socket %d and received %d bytes\n", s, len ); + if( len == -1 ) + return -1; - if( packet_sampling_queue.empty( ) ) - return false; + const uint8_t *buffer = reinterpret_cast( buf ); + if( packet_sampling_enabled ) + { + packet_t p; + std::memcpy( &p.address, from, *fromlen ); + p.address_size = *fromlen; + p.buffer.assign( buffer, buffer + len ); - p = std::move( packet_sampling_queue.front( ) ); - packet_sampling_queue.pop( ); - return true; - } + PushPacketToSamplingQueue( std::move( p ) ); + } - static ssize_t ReceiveAndAnalyzePacket( - SOCKET s, - void *buf, - recvlen_t buflen, - int32_t flags, - sockaddr *from, - socklen_t *fromlen - ) - { - auto trampoline = recvfrom_hook.GetTrampoline( ); - if( trampoline == nullptr ) - return -1; + const sockaddr_in &infrom = *reinterpret_cast( from ); + if( !IsAddressAllowed( infrom ) ) + return -1; - const ssize_t len = trampoline( s, buf, buflen, flags, from, fromlen ); - _DebugWarning( "[ServerSecure] Called recvfrom on socket %d and received %d bytes\n", s, len ); - if( len == -1 ) - return -1; + _DebugWarning( "[ServerSecure] Address %s was allowed\n", IPToString( infrom.sin_addr ) ); - const uint8_t *buffer = reinterpret_cast( buf ); - if( packet_sampling_enabled ) - { - packet_t p; - std::memcpy( &p.address, from, *fromlen ); - p.address_size = *fromlen; - p.buffer.assign( buffer, buffer + len ); + PacketType type = ClassifyPacket( buffer, len, infrom ); + if( type == PacketType::Info ) + type = HandleInfoQuery( infrom ); - PushPacketToSamplingQueue( std::move( p ) ); + return type != PacketType::Invalid ? len : -1; } - const sockaddr_in &infrom = *reinterpret_cast( from ); - if( !IsAddressAllowed( infrom ) ) - return -1; - - _DebugWarning( "[ServerSecure] Address %s was allowed\n", IPToString( infrom.sin_addr ) ); + ssize_t HandleDetour( + SOCKET s, + void *buf, + recvlen_t buflen, + int32_t flags, + sockaddr *from, + socklen_t *fromlen + ) + { + if( s != game_socket ) + { + _DebugWarning( "[ServerSecure] recvfrom detour called with socket %d, passing through\n", s ); + auto trampoline = recvfrom_hook.GetTrampoline( ); + return trampoline != nullptr ? trampoline( s, buf, buflen, flags, from, fromlen ) : -1; + } - PacketType type = ClassifyPacket( buffer, len, infrom ); - if( type == PacketType::Info ) - type = HandleInfoQuery( infrom ); + _DebugWarning( "[ServerSecure] recvfrom detour called with socket %d, detouring\n", s ); - return type != PacketType::Invalid ? len : -1; - } - - static ssize_t SERVERSECURE_CALLING_CONVENTION recvfrom_detour( - SOCKET s, - void *buf, - recvlen_t buflen, - int32_t flags, - sockaddr *from, - socklen_t *fromlen - ) - { - if( s != game_socket ) - { - _DebugWarning( "[ServerSecure] recvfrom detour called with socket %d, passing through\n", s ); - auto trampoline = recvfrom_hook.GetTrampoline( ); - return trampoline != nullptr ? trampoline( s, buf, buflen, flags, from, fromlen ) : -1; - } + packet_t p; + const bool has_packet = PopPacketFromQueue( p ); + if( !has_packet ) + return HandleNetError( -1 ); - _DebugWarning( "[ServerSecure] recvfrom detour called with socket %d, detouring\n", s ); - - packet_t p; - const bool has_packet = PopPacketFromQueue( p ); - if( !has_packet ) - return HandleNetError( -1 ); + const ssize_t len = (std::min)( static_cast( p.buffer.size( ) ), static_cast( buflen ) ); + p.buffer.resize( static_cast( len ) ); + std::copy( p.buffer.begin( ), p.buffer.end( ), static_cast( buf ) ); - const ssize_t len = (std::min)( static_cast( p.buffer.size( ) ), static_cast( buflen ) ); - p.buffer.resize( static_cast( len ) ); - std::copy( p.buffer.begin( ), p.buffer.end( ), static_cast( buf ) ); + const socklen_t addrlen = (std::min)( *fromlen, p.address_size ); + std::memcpy( from, &p.address, static_cast( addrlen ) ); + *fromlen = addrlen; - const socklen_t addrlen = (std::min)( *fromlen, p.address_size ); - std::memcpy( from, &p.address, static_cast( addrlen ) ); - *fromlen = addrlen; + return len; + } - return len; - } + static ssize_t SERVERSECURE_CALLING_CONVENTION recvfrom_detour( + SOCKET s, + void *buf, + recvlen_t buflen, + int32_t flags, + sockaddr *from, + socklen_t *fromlen + ) + { + return Singleton->HandleDetour( s, buf, buflen, flags, from, fromlen ); + } - static uintp PacketReceiverThread( void * ) - { - while( threaded_socket_execute ) + uintp HandleThread( ) { - if( IsPacketQueueFull( ) ) + while( threaded_socket_execute ) { - _DebugWarning( "[ServerSecure] Packet queue is full, sleeping for 100ms\n" ); - ThreadSleep( 100 ); - continue; - } + if( IsPacketQueueFull( ) ) + { + _DebugWarning( "[ServerSecure] Packet queue is full, sleeping for 100ms\n" ); + ThreadSleep( 100 ); + continue; + } - fd_set readables; - FD_ZERO( &readables ); - FD_SET( game_socket, &readables ); - timeval timeout = { 0, 100000 }; - const int32_t res = select( game_socket + 1, &readables, nullptr, nullptr, &timeout ); - if( res == -1 || !FD_ISSET( game_socket, &readables ) ) - continue; + fd_set readables; + FD_ZERO( &readables ); + FD_SET( game_socket, &readables ); + timeval timeout = { 0, 100000 }; + const int32_t res = select( game_socket + 1, &readables, nullptr, nullptr, &timeout ); + if( res == -1 || !FD_ISSET( game_socket, &readables ) ) + continue; + + _DebugWarning( "[ServerSecure] Select passed\n" ); + + packet_t p; + p.buffer.resize( threaded_socket_max_buffer ); + const ssize_t len = ReceiveAndAnalyzePacket( + game_socket, + p.buffer.data( ), + static_cast( threaded_socket_max_buffer ), + 0, + reinterpret_cast( &p.address ), + &p.address_size + ); + if( len == -1 ) + continue; - _DebugWarning( "[ServerSecure] Select passed\n" ); + _DebugWarning( "[ServerSecure] Pushing packet to queue\n" ); - packet_t p; - p.buffer.resize( threaded_socket_max_buffer ); - const ssize_t len = ReceiveAndAnalyzePacket( - game_socket, - p.buffer.data( ), - static_cast( threaded_socket_max_buffer ), - 0, - reinterpret_cast( &p.address ), - &p.address_size - ); - if( len == -1 ) - continue; + p.buffer.resize( static_cast( len ) ); - _DebugWarning( "[ServerSecure] Pushing packet to queue\n" ); + PushPacketToQueue( std::move( p ) ); + } - p.buffer.resize( static_cast( len ) ); + return 0; + } - PushPacketToQueue( std::move( p ) ); + static uintp PacketReceiverThread( void *param ) + { + return static_cast( param )->HandleThread( ); } + }; - return 0; - } + std::unique_ptr Core::Singleton; LUA_FUNCTION_STATIC( EnableFirewallWhitelist ) { LUA->CheckType( 1, GarrysMod::Lua::Type::Bool ); - firewall_whitelist_enabled = LUA->GetBool( 1 ); + Core::Singleton->SetFirewallWhitelistState( LUA->GetBool( 1 ) ); return 0; } @@ -780,27 +942,27 @@ namespace netfilter LUA_FUNCTION_STATIC( AddWhitelistIP ) { LUA->CheckType( 1, GarrysMod::Lua::Type::Number ); - firewall_whitelist.insert( static_cast( LUA->GetNumber( 1 ) ) ); + Core::Singleton->AddWhitelistIP( static_cast( LUA->GetNumber( 1 ) ) ); return 0; } LUA_FUNCTION_STATIC( RemoveWhitelistIP ) { LUA->CheckType( 1, GarrysMod::Lua::Type::Number ); - firewall_whitelist.erase( static_cast( LUA->GetNumber( 1 ) ) ); + Core::Singleton->RemoveWhitelistIP( static_cast( LUA->GetNumber( 1 ) ) ); return 0; } LUA_FUNCTION_STATIC( ResetWhitelist ) { - std::unordered_set( ).swap( firewall_whitelist ); + Core::Singleton->ResetWhitelist( ); return 0; } LUA_FUNCTION_STATIC( EnableFirewallBlacklist ) { LUA->CheckType( 1, GarrysMod::Lua::Type::Bool ); - firewall_blacklist_enabled = LUA->GetBool( 1 ); + Core::Singleton->SetFirewallBlacklistState( LUA->GetBool( 1 ) ); return 0; } @@ -808,76 +970,76 @@ namespace netfilter LUA_FUNCTION_STATIC( AddBlacklistIP ) { LUA->CheckType( 1, GarrysMod::Lua::Type::Number ); - firewall_blacklist.insert( static_cast( LUA->GetNumber( 1 ) ) ); + Core::Singleton->AddBlacklistIP( static_cast( LUA->GetNumber( 1 ) ) ); return 0; } LUA_FUNCTION_STATIC( RemoveBlacklistIP ) { LUA->CheckType( 1, GarrysMod::Lua::Type::Number ); - firewall_blacklist.erase( static_cast( LUA->GetNumber( 1 ) ) ); + Core::Singleton->RemoveBlacklistIP( static_cast( LUA->GetNumber( 1 ) ) ); return 0; } LUA_FUNCTION_STATIC( ResetBlacklist ) { - std::unordered_set( ).swap( firewall_blacklist ); + Core::Singleton->ResetBlacklist( ); return 0; } LUA_FUNCTION_STATIC( EnablePacketValidation ) { LUA->CheckType( 1, GarrysMod::Lua::Type::Bool ); - packet_validation_enabled = LUA->GetBool( 1 ); + Core::Singleton->SetPacketValidationState( LUA->GetBool( 1 ) ); return 0; } LUA_FUNCTION_STATIC( EnableInfoCache ) { LUA->CheckType( 1, GarrysMod::Lua::Type::Bool ); - info_cache_enabled = LUA->GetBool( 1 ); + Core::Singleton->SetInfoCacheState( LUA->GetBool( 1 ) ); return 0; } LUA_FUNCTION_STATIC( SetInfoCacheTime ) { LUA->CheckType( 1, GarrysMod::Lua::Type::Number ); - info_cache_time = static_cast( LUA->GetNumber( 1 ) ); + Core::Singleton->SetInfoCacheTime( static_cast( LUA->GetNumber( 1 ) ) ); return 0; } LUA_FUNCTION_STATIC( RefreshInfoCache ) { - BuildStaticReplyInfo( nullptr ); - BuildReplyInfo( ); + Core::Singleton->BuildStaticReplyInfo( nullptr ); + Core::Singleton->BuildReplyInfo( ); return 0; } LUA_FUNCTION_STATIC( EnableQueryLimiter ) { LUA->CheckType( 1, GarrysMod::Lua::Type::Bool ); - client_manager.SetState( LUA->GetBool( 1 ) ); + Core::Singleton->GetClientManager( ).SetState( LUA->GetBool( 1 ) ); return 0; } LUA_FUNCTION_STATIC( SetMaxQueriesWindow ) { LUA->CheckType( 1, GarrysMod::Lua::Type::Number ); - client_manager.SetMaxQueriesWindow( static_cast( LUA->GetNumber( 1 ) ) ); + Core::Singleton->GetClientManager( ).SetMaxQueriesWindow( static_cast( LUA->GetNumber( 1 ) ) ); return 0; } LUA_FUNCTION_STATIC( SetMaxQueriesPerSecond ) { LUA->CheckType( 1, GarrysMod::Lua::Type::Number ); - client_manager.SetMaxQueriesPerSecond( static_cast( LUA->GetNumber( 1 ) ) ); + Core::Singleton->GetClientManager( ).SetMaxQueriesPerSecond( static_cast( LUA->GetNumber( 1 ) ) ); return 0; } LUA_FUNCTION_STATIC( SetGlobalMaxQueriesPerSecond ) { LUA->CheckType( 1, GarrysMod::Lua::Type::Number ); - client_manager.SetGlobalMaxQueriesPerSecond( + Core::Singleton->GetClientManager( ).SetGlobalMaxQueriesPerSecond( static_cast( LUA->GetNumber( 1 ) ) ); return 0; @@ -886,21 +1048,14 @@ namespace netfilter LUA_FUNCTION_STATIC( EnablePacketSampling ) { LUA->CheckType( 1, GarrysMod::Lua::Type::Bool ); - - packet_sampling_enabled = LUA->GetBool( 1 ); - if( !packet_sampling_enabled ) - { - AUTO_LOCK( packet_sampling_mutex ); - std::queue( ).swap( packet_sampling_queue ); - } - + Core::Singleton->SetPacketSamplingState( LUA->GetBool( 1 ) ); return 0; } LUA_FUNCTION_STATIC( GetSamplePacket ) { - packet_t p; - if( !PopPacketFromSamplingQueue( p ) ) + Core::packet_t p; + if( !Core::Singleton->PopPacketFromSamplingQueue( p ) ) return 0; LUA->PushNumber( p.address.sin_addr.s_addr ); @@ -915,9 +1070,20 @@ namespace netfilter typedef CBaseServer TargetClass; typedef CBaseServerProxy SubstituteClass; - CBaseServerProxy( ) = default; - public: + CBaseServerProxy( CBaseServer *baseserver ) + { + Initialize( baseserver ); + Hook( &CBaseServer::CheckChallengeNr, &CBaseServerProxy::CheckChallengeNr ); + Hook( &CBaseServer::GetChallengeNr, &CBaseServerProxy::GetChallengeNr ); + } + + ~CBaseServerProxy( ) + { + UnHook( &CBaseServer::CheckChallengeNr ); + UnHook( &CBaseServer::GetChallengeNr ); + } + virtual bool CheckChallengeNr( netadr_t &adr, int nChallengeValue ) { // See if the challenge is valid @@ -988,7 +1154,7 @@ namespace netfilter static std::array m_previous_challenge; static std::array m_challenge; - static CBaseServerProxy Singleton; + static std::unique_ptr Singleton; }; std::mt19937 CBaseServerProxy::m_rng( std::random_device { } ( ) ); @@ -996,75 +1162,32 @@ namespace netfilter std::array CBaseServerProxy::m_previous_challenge; std::array CBaseServerProxy::m_challenge; - CBaseServerProxy CBaseServerProxy::Singleton; + std::unique_ptr CBaseServerProxy::Singleton; void Initialize( GarrysMod::Lua::ILuaBase *LUA ) { - if( !server_loader.IsValid( ) ) - LUA->ThrowError( "unable to get server factory" ); - - { - ICvar *icvar = InterfacePointers::Cvar( ); - if( icvar != nullptr ) - { - sv_visiblemaxplayers = icvar->FindVar( "sv_visiblemaxplayers" ); - sv_location = icvar->FindVar( "sv_location" ); - } - - if( sv_visiblemaxplayers == nullptr ) - ConColorMsg( Color( 255, 255, 0, 255 ), "[ServerSecure] Failed to get \"sv_visiblemaxplayers\" convar!\n" ); - - if( sv_location == nullptr ) - ConColorMsg( Color( 255, 255, 0, 255 ), "[ServerSecure] Failed to get \"sv_location\" convar!\n" ); - } - - gamedll = InterfacePointers::ServerGameDLL( ); - if( gamedll == nullptr ) - LUA->ThrowError( "failed to load required IServerGameDLL interface" ); + LUA->GetField( GarrysMod::Lua::INDEX_GLOBAL, "VERSION" ); + const char *game_version = LUA->CheckString( -1 ); - engine_server = InterfacePointers::VEngineServer( ); - if( engine_server == nullptr ) - LUA->ThrowError( "failed to load required IVEngineServer interface" ); - - filesystem = InterfacePointers::FileSystem( ); - if( filesystem == nullptr ) - LUA->ThrowError( "failed to initialize IFileSystem" ); - - CBaseServer *baseserver = static_cast( InterfacePointers::Server( ) ); - if( baseserver != nullptr ) + bool errored = false; + try { - CBaseServerProxy::Singleton.Initialize( baseserver ); - CBaseServerProxy::Singleton.Hook( &CBaseServer::CheckChallengeNr, &CBaseServerProxy::CheckChallengeNr ); - CBaseServerProxy::Singleton.Hook( &CBaseServer::GetChallengeNr, &CBaseServerProxy::GetChallengeNr ); + Core::Singleton = std::make_unique( game_version ); } - + catch( const std::exception &e ) { - const FunctionPointers::GMOD_GetNetSocket_t GetNetSocket = FunctionPointers::GMOD_GetNetSocket( ); - if( GetNetSocket != nullptr ) - { - const netsocket_t *net_socket = GetNetSocket( 1 ); - if( net_socket != nullptr ) - game_socket = net_socket->hUDP; - } + errored = true; + LUA->PushString( e.what( ) ); } - if( game_socket == INVALID_SOCKET ) - LUA->ThrowError( "got an invalid server socket" ); - - if( !recvfrom_hook.Enable( ) ) - LUA->ThrowError( "failed to detour recvfrom" ); + if( errored ) + LUA->Error( ); - threaded_socket_execute = true; - threaded_socket_handle = CreateSimpleThread( PacketReceiverThread, nullptr ); - if( threaded_socket_handle == nullptr ) - LUA->ThrowError( "unable to create thread" ); + LUA->Pop( 1 ); - { - LUA->GetField( GarrysMod::Lua::INDEX_GLOBAL, "VERSION" ); - const char *game_version = LUA->CheckString( -1 ); - BuildStaticReplyInfo( game_version ); - LUA->Pop( 1 ); - } + CBaseServer *baseserver = static_cast( InterfacePointers::Server( ) ); + if( baseserver != nullptr ) + CBaseServerProxy::Singleton = std::make_unique( baseserver ); LUA->PushCFunction( EnableFirewallWhitelist ); LUA->SetField( -2, "EnableFirewallWhitelist" ); @@ -1121,16 +1244,9 @@ namespace netfilter LUA->SetField( -2, "GetSamplePacket" ); } - void Deinitialize( GarrysMod::Lua::ILuaBase * ) + void Deinitialize( ) { - if( threaded_socket_handle != nullptr ) - { - threaded_socket_execute = false; - ThreadJoin( threaded_socket_handle ); - ReleaseThreadHandle( threaded_socket_handle ); - threaded_socket_handle = nullptr; - } - - recvfrom_hook.Disable( ); + CBaseServerProxy::Singleton.reset( ); + Core::Singleton.reset( ); } } diff --git a/source/netfilter/core.hpp b/source/netfilter/core.hpp index 41c866e..20c6a90 100644 --- a/source/netfilter/core.hpp +++ b/source/netfilter/core.hpp @@ -11,5 +11,5 @@ namespace GarrysMod namespace netfilter { void Initialize( GarrysMod::Lua::ILuaBase *LUA ); - void Deinitialize( GarrysMod::Lua::ILuaBase *LUA ); + void Deinitialize( ); }