diff --git a/Core/Helpers/NotificationSendHelper.cs b/Core/Helpers/NotificationSendHelper.cs index 03b5ad1e..24bd88df 100644 --- a/Core/Helpers/NotificationSendHelper.cs +++ b/Core/Helpers/NotificationSendHelper.cs @@ -39,7 +39,7 @@ public class NotificationSendHelper { if (_sptWebSocketConnectionHandler.IsWebSocketConnected(sessionID)) { - _sptWebSocketConnectionHandler.SendMessageAsync(sessionID, notificationMessage).Wait(); + _sptWebSocketConnectionHandler.SendMessage(sessionID, notificationMessage); } else { diff --git a/Core/Servers/WebSocketServer.cs b/Core/Servers/WebSocketServer.cs index 4cbcdd05..7038aeb7 100644 --- a/Core/Servers/WebSocketServer.cs +++ b/Core/Servers/WebSocketServer.cs @@ -25,30 +25,22 @@ public class WebSocketServer _jsonUtil = jsonUtil; } - public Task OnConnection(HttpContext httpContext) + public async Task OnConnection(HttpContext httpContext) { - return httpContext.WebSockets.AcceptWebSocketAsync() - .ContinueWith(task => HandleCommunication(httpContext, task.Result)); + var socket = await httpContext.WebSockets.AcceptWebSocketAsync(); + await HandleCommunication(httpContext, socket); } - private void HandleCommunication(HttpContext context, WebSocket webSocket) + private Task HandleCommunication(HttpContext context, WebSocket webSocket) { var socketHandlers = _webSocketConnectionHandler .Where(wsh => context.Request.Path.Value.Contains(wsh.GetHookUrl())) .ToList(); if (socketHandlers.Count == 0) { - var message = - $"Socket connection received for url {context.Request.Path.Value}, but there is not websocket handler configured for it"; + var message = $"Socket connection received for url {context.Request.Path.Value}, but there is not websocket handler configured for it"; _logger.Warning(message); - webSocket.SendAsync( - Encoding.UTF8.GetBytes(_jsonUtil.Serialize(new { error = message })), - WebSocketMessageType.Text, - true, - CancellationToken.None - ) - .Wait(); - webSocket.CloseAsync(WebSocketCloseStatus.ProtocolError, message, CancellationToken.None).Wait(); + return webSocket.CloseAsync(WebSocketCloseStatus.ProtocolError, message, CancellationToken.None); } foreach (var wsh in socketHandlers) @@ -56,5 +48,6 @@ public class WebSocketServer wsh.OnConnection(webSocket, context).Wait(); _logger.Info($"WebSocketHandler \"{wsh.GetSocketId()}\" connected"); } + return Task.CompletedTask; } } diff --git a/Core/Servers/Ws/IWebSocketConnectionHandler.cs b/Core/Servers/Ws/IWebSocketConnectionHandler.cs index 07efc9e5..554a9cd7 100644 --- a/Core/Servers/Ws/IWebSocketConnectionHandler.cs +++ b/Core/Servers/Ws/IWebSocketConnectionHandler.cs @@ -10,5 +10,5 @@ public interface IWebSocketConnectionHandler Task OnConnection(WebSocket ws, HttpContext context); bool IsWebSocketConnected(string sessionId); - Task SendMessageAsync(string sessionID, WsNotificationEvent output); + void SendMessage(string sessionID, WsNotificationEvent output); } diff --git a/Core/Servers/Ws/SptWebSocketConnectionHandler.cs b/Core/Servers/Ws/SptWebSocketConnectionHandler.cs index 3497c4f2..af9e1998 100644 --- a/Core/Servers/Ws/SptWebSocketConnectionHandler.cs +++ b/Core/Servers/Ws/SptWebSocketConnectionHandler.cs @@ -17,6 +17,8 @@ public class SptWebSocketConnectionHandler : IWebSocketConnectionHandler protected Dictionary _sockets = new(); protected Dictionary _socketAliveTimers = new(); protected Dictionary _receiveTasks = new(); + protected object _lockObject = new(); + protected ISptLogger _logger; protected LocalisationService _localisationService; protected JsonUtil _jsonUtil; @@ -48,56 +50,75 @@ public class SptWebSocketConnectionHandler : IWebSocketConnectionHandler public Task OnConnection(WebSocket ws, HttpContext context) { - var splitUrl = context.Request.Path.Value.Split("/"); - var sessionID = splitUrl.Last(); - var playerProfile = _profileHelper.GetFullProfile(sessionID); - var playerInfoText = $"{playerProfile.ProfileInfo.Username} ({sessionID})"; + return Task.Factory.StartNew( + () => + { + var splitUrl = context.Request.Path.Value.Split("/"); + var sessionID = splitUrl.Last(); + var playerProfile = _profileHelper.GetFullProfile(sessionID); + var playerInfoText = $"{playerProfile.ProfileInfo.Username} ({sessionID})"; - _logger.Info(_localisationService.GetText("websocket-player_connected", playerInfoText)); + _logger.Info(_localisationService.GetText("websocket-player_connected", playerInfoText)); - _sockets.Add(sessionID, ws); + _sockets.Add(sessionID, ws); - _socketAliveTimers.Add( - sessionID, - new Timer( - async (_) => + lock (_lockObject) { - _logger.Debug(_localisationService.GetText("websocket-pinging_player", sessionID)); + _socketAliveTimers.Add( + sessionID, + new Timer( + _ => { TimedTask(ws, sessionID); }, + null, + TimeSpan.Zero, + TimeSpan.FromMilliseconds(_httpConfig.WebSocketPingDelayMs) + ) + ); + } - if (ws.State == WebSocketState.Open) - { - await ws.SendAsync( - Encoding.UTF8.GetBytes(_jsonUtil.Serialize(_defaultNotification)), - WebSocketMessageType.Text, - true, - CancellationToken.None - ); - } - else - { - _logger.Debug(_localisationService.GetText("websocket-socket_lost_deleting_handle")); - var timer = _socketAliveTimers[sessionID]; - timer.Change(Timeout.Infinite, Timeout.Infinite); - _socketAliveTimers.Remove(sessionID); - _sockets.Remove(sessionID); - var receiveTask = _receiveTasks[sessionID]; - await receiveTask.CancelAsync(); - } - }, - null, - TimeSpan.Zero, - TimeSpan.FromMilliseconds(_httpConfig.WebSocketPingDelayMs) - ) + lock (_lockObject) + { + _receiveTasks.Add(sessionID, new()); + var cancelToken = _receiveTasks[sessionID].Token; + Task.Factory.StartNew(_ => ReceiveTask(sessionID, ws, cancelToken), null, cancelToken); + } + + while (ws.State == WebSocketState.Open) + { + Thread.Sleep(1000); + } + + // Once the websocket dies, we dispose of it + _logger.Debug(_localisationService.GetText("websocket-socket_lost_deleting_handle")); + lock (_lockObject) + { + var timer = _socketAliveTimers[sessionID]; + timer.Change(Timeout.Infinite, Timeout.Infinite); + _socketAliveTimers.Remove(sessionID); + _sockets.Remove(sessionID); + var receiveTask = _receiveTasks[sessionID]; + receiveTask.CancelAsync().Wait(); + } + } ); - - _receiveTasks.Add(sessionID, new()); - var cancelToken = _receiveTasks[sessionID].Token; - Task.Factory.StartNew((_) => ReceiveTask(sessionID, ws, cancelToken), null, cancelToken); - - return Task.CompletedTask; } - public async Task SendMessageAsync(string sessionID, WsNotificationEvent output) + private void TimedTask(WebSocket ws, string sessionID) + { + _logger.Debug(_localisationService.GetText("websocket-pinging_player", sessionID)); + + if (ws.State == WebSocketState.Open) + { + var sendTask = ws.SendAsync( + Encoding.UTF8.GetBytes(_jsonUtil.Serialize(_defaultNotification)), + WebSocketMessageType.Text, + true, + CancellationToken.None + ); + sendTask.Wait(); + } + } + + public void SendMessage(string sessionID, WsNotificationEvent output) { try { @@ -105,12 +126,13 @@ public class SptWebSocketConnectionHandler : IWebSocketConnectionHandler { var ws = GetSessionWebSocket(sessionID); - await ws.SendAsync( + var sendTask = ws.SendAsync( Encoding.UTF8.GetBytes(_jsonUtil.Serialize(output)), WebSocketMessageType.Text, true, CancellationToken.None ); + sendTask.Wait(); _logger.Debug(_localisationService.GetText("websocket-message_sent")); } else @@ -142,6 +164,7 @@ public class SptWebSocketConnectionHandler : IWebSocketConnectionHandler readBytes.AddRange(buffer); isEndOfMessage = readTask.Result.EndOfMessage; } + foreach (var sptWebSocketMessageHandler in _messageHandlers) { sptWebSocketMessageHandler.OnSptMessage(sessionID, ws, readBytes.ToArray()).Wait(); diff --git a/Server/Program.cs b/Server/Program.cs index c0c17575..bb392c1d 100644 --- a/Server/Program.cs +++ b/Server/Program.cs @@ -71,6 +71,9 @@ public static class Program builder.Logging.ClearProviders(); logger = new LoggerConfiguration() .ReadFrom.Configuration(builder.Configuration) + # if DEBUG + .MinimumLevel.Debug() + # endif .MinimumLevel.Override("Microsoft.AspNetCore.Hosting.Diagnostics", LogEventLevel.Warning) .Enrich.FromLogContext() .Enrich.WithThreadId()