From a58065871f84ec785991bf8650655aa80b9283f8 Mon Sep 17 00:00:00 2001 From: clodanSPT Date: Wed, 28 May 2025 11:42:56 +0100 Subject: [PATCH] =?UTF-8?q?Removed=20ConcurrentDictionary=20in=20favor=20o?= =?UTF-8?q?f=20locked=20dictionary,=20and=20added=E2=80=A6=20(#287)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Removed ConcurrentDictionary in favor of locked dictionary, and added reconnection behaviour to OnConnection * Refactored code to handle multiple ws connected at the same time and offer graceful disconnection of appropriate sockets * removed unused usings --------- Co-authored-by: Alex Co-authored-by: Chomp <27521899+chompDev@users.noreply.github.com> --- .../Assets/database/locales/server/en.json | 1 + .../Servers/WebSocketServer.cs | 6 +- .../Servers/Ws/IWebSocketConnectionHandler.cs | 4 +- .../Ws/SptWebSocketConnectionHandler.cs | 114 +++++++++++++----- 4 files changed, 88 insertions(+), 37 deletions(-) diff --git a/Libraries/SPTarkov.Server.Assets/Assets/database/locales/server/en.json b/Libraries/SPTarkov.Server.Assets/Assets/database/locales/server/en.json index 6857cba2..cc92f3ba 100644 --- a/Libraries/SPTarkov.Server.Assets/Assets/database/locales/server/en.json +++ b/Libraries/SPTarkov.Server.Assets/Assets/database/locales/server/en.json @@ -722,6 +722,7 @@ "websocket-not_ready_message_not_sent": "[WS] Socket not ready for %s, message not sent", "websocket-pinging_player": "[WS] Pinging player: %s", "websocket-player_connected": "[WS] Player: %s has connected", + "websocket-player_reconnect": "[WS] Player: %s reconnection received, closing previous active socket", "websocket-received_message": "[WS] Received message from user %s ", "websocket-socket_lost_deleting_handle": "[WS] Socket lost, deleting handle", "websocket-started": "Started websocket at %s", diff --git a/Libraries/SPTarkov.Server.Core/Servers/WebSocketServer.cs b/Libraries/SPTarkov.Server.Core/Servers/WebSocketServer.cs index 70bfd44d..5e1bf900 100644 --- a/Libraries/SPTarkov.Server.Core/Servers/WebSocketServer.cs +++ b/Libraries/SPTarkov.Server.Core/Servers/WebSocketServer.cs @@ -34,6 +34,8 @@ public class WebSocketServer( return; } + var sessionIdContext = DateTime.UtcNow.ToString("yyyyMMddHHmmssfff"); + foreach (var wsh in socketHandlers) { if (webSocket.State == WebSocketState.Open) @@ -44,7 +46,7 @@ public class WebSocketServer( } } - await wsh.OnConnection(webSocket, context); + await wsh.OnConnection(webSocket, context, sessionIdContext); } // Discard this task, we dont need to await it. @@ -79,7 +81,7 @@ public class WebSocketServer( foreach (var wsh in socketHandlers) { await cts.CancelAsync(); - await wsh.OnClose(webSocket, context); + await wsh.OnClose(webSocket, context, sessionIdContext); } } } diff --git a/Libraries/SPTarkov.Server.Core/Servers/Ws/IWebSocketConnectionHandler.cs b/Libraries/SPTarkov.Server.Core/Servers/Ws/IWebSocketConnectionHandler.cs index cf002305..5fa80a69 100644 --- a/Libraries/SPTarkov.Server.Core/Servers/Ws/IWebSocketConnectionHandler.cs +++ b/Libraries/SPTarkov.Server.Core/Servers/Ws/IWebSocketConnectionHandler.cs @@ -6,7 +6,7 @@ public interface IWebSocketConnectionHandler { string GetHookUrl(); string GetSocketId(); - Task OnConnection(WebSocket ws, HttpContext context); + Task OnConnection(WebSocket ws, HttpContext context, string sessionIdContext); Task OnMessage(byte[] rawData, WebSocketMessageType messageType, WebSocket ws, HttpContext context); - Task OnClose(WebSocket ws, HttpContext context); + Task OnClose(WebSocket ws, HttpContext context, string sessionIdContext); } diff --git a/Libraries/SPTarkov.Server.Core/Servers/Ws/SptWebSocketConnectionHandler.cs b/Libraries/SPTarkov.Server.Core/Servers/Ws/SptWebSocketConnectionHandler.cs index 9ba230d9..55c3a216 100644 --- a/Libraries/SPTarkov.Server.Core/Servers/Ws/SptWebSocketConnectionHandler.cs +++ b/Libraries/SPTarkov.Server.Core/Servers/Ws/SptWebSocketConnectionHandler.cs @@ -1,4 +1,3 @@ -using System.Collections.Concurrent; using System.Net.WebSockets; using System.Text; using SPTarkov.DI.Annotations; @@ -18,12 +17,11 @@ public class SptWebSocketConnectionHandler( LocalisationService _localisationService, JsonUtil _jsonUtil, ProfileHelper _profileHelper, - ConfigServer _configServer, IEnumerable _messageHandlers ) : IWebSocketConnectionHandler { - protected WsPing _defaultNotification = new(); - protected ConcurrentDictionary _sockets = new(); + protected Dictionary> _sockets = new(); + protected Lock _socketsLock = new(); public string GetHookUrl() { @@ -35,29 +33,57 @@ public class SptWebSocketConnectionHandler( return "SPT WebSocket Handler"; } - public Task OnConnection(WebSocket ws, HttpContext context) + public Task OnConnection(WebSocket ws, HttpContext context, string sessionIdContext) { var splitUrl = context.Request.Path.Value.Split("/"); var sessionID = splitUrl.Last(); var playerProfile = _profileHelper.GetFullProfile(sessionID); var playerInfoText = $"{playerProfile.ProfileInfo.Username} ({sessionID})"; - - _logger.Info(_localisationService.GetText("websocket-player_connected", playerInfoText)); - - if (!_sockets.TryAdd(sessionID, ws) && _logger.IsLogEnabled(LogLevel.Debug)) + lock (_socketsLock) { - _logger.Debug($"[ws] player: {playerInfoText} has already connected"); - } + if (_sockets.TryGetValue(sessionID, out var sessionSockets) && sessionSockets.Any()) + { + if (_logger.IsLogEnabled(LogLevel.Debug)) + { + _logger.Debug(_localisationService.GetText("websocket-player_reconnect", playerInfoText)); + } - return Task.CompletedTask; + foreach (var oldSocket in sessionSockets) + { + if (_logger.IsLogEnabled(LogLevel.Debug)) + { + _logger.Debug($"[ws] Removing websocket reference {oldSocket.Key} for session {sessionID}"); + } + + oldSocket.Value.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None).Wait(); + } + + sessionSockets.Clear(); + } + else + { + sessionSockets = new Dictionary(); + _sockets.Add(sessionID, sessionSockets); + } + + sessionSockets.Add(sessionIdContext, ws); + if (_logger.IsLogEnabled(LogLevel.Info)) + { + _logger.Info(_localisationService.GetText("websocket-player_connected", playerInfoText)); + } + + return Task.CompletedTask; + } } - public async Task OnMessage(byte[] receivedMessage, WebSocketMessageType messageType, WebSocket ws, HttpContext context) + public async Task OnMessage( + byte[] receivedMessage, + WebSocketMessageType messageType, + WebSocket ws, + HttpContext context) { var splitUrl = context.Request.Path.Value.Split("/"); var sessionID = splitUrl.Last(); - var playerProfile = _profileHelper.GetFullProfile(sessionID); - var playerInfoText = $"{playerProfile.ProfileInfo.Username} ({sessionID})"; foreach (var sptWebSocketMessageHandler in _messageHandlers) { @@ -65,19 +91,31 @@ public class SptWebSocketConnectionHandler( } } - public async Task OnClose(WebSocket ws, HttpContext context) + public async Task OnClose(WebSocket ws, HttpContext context, string sessionIdContext) { var splitUrl = context.Request.Path.Value.Split("/"); var sessionID = splitUrl.Last(); - if (!_sockets.Remove(sessionID, out _) && _logger.IsLogEnabled(LogLevel.Debug)) + lock (_socketsLock) { - _logger.Debug($"[ws] Error removing socket for session: {sessionID}"); + if (_sockets.TryGetValue(sessionID, out var sessionSockets) && sessionSockets.Any()) + { + if (!sessionSockets.TryGetValue(sessionIdContext, out _) && _logger.IsLogEnabled(LogLevel.Info)) + { + _logger.Info($"[ws] The websocket session {sessionID} with reference {sessionIdContext} has already been removed or reconnected"); + } + else + { + sessionSockets.Remove(sessionIdContext); + if (_logger.IsLogEnabled(LogLevel.Info)) + { + var playerProfile = _profileHelper.GetFullProfile(sessionID); + var playerInfoText = $"{playerProfile.ProfileInfo.Username} ({sessionID})"; + _logger.Info($"[ws] player: {playerInfoText} has disconnected"); + } + } + } } - - var playerProfile = _profileHelper.GetFullProfile(sessionID); - var playerInfoText = $"{playerProfile.ProfileInfo.Username} ({sessionID})"; - _logger.Info($"[ws] player: {playerInfoText} has disconnected"); } public void SendMessage(string sessionID, WsNotificationEvent output) @@ -86,15 +124,19 @@ public class SptWebSocketConnectionHandler( { if (IsWebSocketConnected(sessionID)) { - var ws = GetSessionWebSocket(sessionID); + var webSockets = GetSessionWebSocket(sessionID); + + foreach (var webSocket in webSockets) + { + var sendTask = webSocket.SendAsync( + Encoding.UTF8.GetBytes(_jsonUtil.Serialize(output, output.GetType())), + WebSocketMessageType.Text, + true, + CancellationToken.None + ); + sendTask.Wait(); + } - var sendTask = ws.SendAsync( - Encoding.UTF8.GetBytes(_jsonUtil.Serialize(output, output.GetType())), - WebSocketMessageType.Text, - true, - CancellationToken.None - ); - sendTask.Wait(); if (_logger.IsLogEnabled(LogLevel.Debug)) { _logger.Debug(_localisationService.GetText("websocket-message_sent")); @@ -116,11 +158,17 @@ public class SptWebSocketConnectionHandler( public bool IsWebSocketConnected(string sessionID) { - return _sockets.TryGetValue(sessionID, out var socket) && socket.State == WebSocketState.Open; + lock (_socketsLock) + { + return _sockets.TryGetValue(sessionID, out var sockets) && sockets.Any(s => s.Value.State == WebSocketState.Open); + } } - public WebSocket GetSessionWebSocket(string sessionID) + public IEnumerable GetSessionWebSocket(string sessionID) { - return _sockets.GetValueOrDefault(sessionID); + lock (_socketsLock) + { + return _sockets.GetValueOrDefault(sessionID)?.Values.Where(s => s.State == WebSocketState.Open); + } } }