diff --git a/Libraries/Core/Helpers/NotificationSendHelper.cs b/Libraries/Core/Helpers/NotificationSendHelper.cs index 4f483932..4d0fc599 100644 --- a/Libraries/Core/Helpers/NotificationSendHelper.cs +++ b/Libraries/Core/Helpers/NotificationSendHelper.cs @@ -11,7 +11,7 @@ namespace Core.Helpers; [Injectable] public class NotificationSendHelper( - IWebSocketConnectionHandler _sptWebSocketConnectionHandler, + SptWebSocketConnectionHandler _sptWebSocketConnectionHandler, HashUtil _hashUtil, SaveServer _saveServer, NotificationService _notificationService, diff --git a/Libraries/Core/Servers/HttpServer.cs b/Libraries/Core/Servers/HttpServer.cs index da98af85..e61b0c17 100644 --- a/Libraries/Core/Servers/HttpServer.cs +++ b/Libraries/Core/Servers/HttpServer.cs @@ -29,12 +29,16 @@ public class HttpServer( var app = builder?.Build(); // enable web socket - app?.UseWebSockets(); + app?.UseWebSockets(new WebSocketOptions + { + // Every minute a heartbeat is sent to keep the connection alive. + KeepAliveInterval = TimeSpan.FromSeconds(60) + }); app?.Use( (HttpContext req, RequestDelegate _) => { - return Task.Factory.StartNew(() => HandleFallback(req)); + return Task.Factory.StartNew(async () => await HandleFallback(req)); } ); started = true; @@ -46,11 +50,12 @@ public class HttpServer( _applicationContext.AddValue(ContextVariableType.WEB_APPLICATION, app); } - private Task HandleFallback(HttpContext context) + private async Task HandleFallback(HttpContext context) { if (context.WebSockets.IsWebSocketRequest) { - return _webSocketServer.OnConnection(context); + await _webSocketServer.OnConnection(context); + return; } context.Request.Cookies.TryGetValue("PHPSESSID", out var sessionId); @@ -103,8 +108,6 @@ public class HttpServer( _httpListeners.SingleOrDefault(l => l.CanHandle(sessionId, context.Request))?.Handle(sessionId, context.Request, context.Response); // This http request would be passed through the SPT Router and handled by an ICallback - - return Task.CompletedTask; } private bool? IsLocalRequest(string? remoteAddress) diff --git a/Libraries/Core/Servers/WebSocketServer.cs b/Libraries/Core/Servers/WebSocketServer.cs index edb51912..7a01edab 100644 --- a/Libraries/Core/Servers/WebSocketServer.cs +++ b/Libraries/Core/Servers/WebSocketServer.cs @@ -1,7 +1,6 @@ using System.Net.WebSockets; using Core.Models.Utils; using Core.Servers.Ws; -using Core.Utils; using SptCommon.Annotations; namespace Core.Servers; @@ -9,37 +8,68 @@ namespace Core.Servers; [Injectable(InjectionType.Singleton)] public class WebSocketServer( IEnumerable _webSocketConnectionHandler, - ISptLogger _logger, - JsonUtil _jsonUtil + ISptLogger _logger ) { public async Task OnConnection(HttpContext httpContext) { var socket = await httpContext.WebSockets.AcceptWebSocketAsync(); - await HandleCommunication(httpContext, socket); + await HandleWebSocket(httpContext, socket); } - private Task HandleCommunication(HttpContext context, WebSocket webSocket) + private async Task HandleWebSocket(HttpContext context, WebSocket webSocket) { var socketHandlers = _webSocketConnectionHandler .Where(wsh => context.Request.Path.Value.Contains(wsh.GetHookUrl())) .ToList(); + if (socketHandlers.Count == 0) { var message = $"Socket connection received for url {context.Request.Path.Value}, but there is not websocket handler configured for it"; - _logger.Warning(message); - return webSocket.CloseAsync(WebSocketCloseStatus.ProtocolError, message, CancellationToken.None); + await webSocket.CloseAsync(WebSocketCloseStatus.ProtocolError, message, CancellationToken.None); + return; } foreach (var wsh in socketHandlers) { - wsh.OnConnection(webSocket, context).Wait(); if (webSocket.State == WebSocketState.Open) { _logger.Info($"WebSocketHandler \"{wsh.GetSocketId()}\" connected"); } + + await wsh.OnConnection(webSocket, context); } - return Task.CompletedTask; + var messageBuffer = new byte[1024]; + + try + { + while (webSocket.State == WebSocketState.Open) + { + var receiveResult = await webSocket.ReceiveAsync(new ArraySegment(messageBuffer), CancellationToken.None); + + if (receiveResult.MessageType == WebSocketMessageType.Text || receiveResult.MessageType == WebSocketMessageType.Binary) + { + foreach (var wsh in socketHandlers) + { + await wsh.OnMessage(messageBuffer.ToArray(), receiveResult.MessageType, webSocket, context); + } + } + else if (receiveResult.MessageType == WebSocketMessageType.Close) + { + foreach (var wsh in socketHandlers) + { + await wsh.OnClose(webSocket, context); + } + } + } + } + catch (Exception) + { + foreach (var wsh in socketHandlers) + { + await wsh.OnClose(webSocket, context); + } + } } } diff --git a/Libraries/Core/Servers/Ws/IWebSocketConnectionHandler.cs b/Libraries/Core/Servers/Ws/IWebSocketConnectionHandler.cs index 554a9cd7..374d4c11 100644 --- a/Libraries/Core/Servers/Ws/IWebSocketConnectionHandler.cs +++ b/Libraries/Core/Servers/Ws/IWebSocketConnectionHandler.cs @@ -1,5 +1,4 @@ using System.Net.WebSockets; -using Core.Models.Eft.Ws; namespace Core.Servers.Ws; @@ -8,7 +7,6 @@ public interface IWebSocketConnectionHandler string GetHookUrl(); string GetSocketId(); Task OnConnection(WebSocket ws, HttpContext context); - bool IsWebSocketConnected(string sessionId); - - void SendMessage(string sessionID, WsNotificationEvent output); + Task OnMessage(byte[] rawData, WebSocketMessageType messageType, WebSocket ws, HttpContext context); + Task OnClose(WebSocket ws, HttpContext context); } diff --git a/Libraries/Core/Servers/Ws/SptWebSocketConnectionHandler.cs b/Libraries/Core/Servers/Ws/SptWebSocketConnectionHandler.cs index 4f595c92..4f2e0ba8 100644 --- a/Libraries/Core/Servers/Ws/SptWebSocketConnectionHandler.cs +++ b/Libraries/Core/Servers/Ws/SptWebSocketConnectionHandler.cs @@ -2,7 +2,6 @@ using System.Net.WebSockets; using System.Text; using Core.Helpers; using Core.Models.Eft.Ws; -using Core.Models.Spt.Config; using Core.Models.Utils; using Core.Servers.Ws.Message; using Core.Services; @@ -24,9 +23,6 @@ public class SptWebSocketConnectionHandler( { protected WsPing _defaultNotification = new(); protected Lock _lockObject = new(); - protected Dictionary _receiveTasks = new(); - protected Dictionary _socketAliveTimers = new(); - protected Dictionary _sockets = new(); public string GetHookUrl() @@ -41,53 +37,48 @@ public class SptWebSocketConnectionHandler( public Task OnConnection(WebSocket ws, HttpContext context) { - return Task.Factory.StartNew( - () => - { - var splitUrl = context.Request.Path.Value.Split("/"); - var sessionID = splitUrl.Last(); - var playerProfile = _profileHelper.GetFullProfile(sessionID); - var playerInfoText = $"{playerProfile.ProfileInfo.Username} ({sessionID})"; + var splitUrl = context.Request.Path.Value.Split("/"); + var sessionID = splitUrl.Last(); + var playerProfile = _profileHelper.GetFullProfile(sessionID); + var playerInfoText = $"{playerProfile.ProfileInfo.Username} ({sessionID})"; - _logger.Info(_localisationService.GetText("websocket-player_connected", playerInfoText)); + _logger.Info(_localisationService.GetText("websocket-player_connected", playerInfoText)); - _sockets.Add(sessionID, ws); + lock (_lockObject) + { + _sockets.Add(sessionID, ws); + } - lock (_lockObject) - { - _receiveTasks.Add(sessionID, new CancellationTokenSource()); - var cancelToken = _receiveTasks[sessionID].Token; - Task.Factory.StartNew(_ => ReceiveTask(sessionID, ws, cancelToken), null, cancelToken); - } + return Task.CompletedTask; + } - while (ws.State == WebSocketState.Open) - { - Thread.Sleep(1000); - } + public async Task OnMessage(byte[] receivedMessage, WebSocketMessageType messageType, WebSocket ws, HttpContext context) + { + var splitUrl = context.Request.Path.Value.Split("/"); + var sessionID = splitUrl.Last(); + var playerProfile = _profileHelper.GetFullProfile(sessionID); + var playerInfoText = $"{playerProfile.ProfileInfo.Username} ({sessionID})"; - // Once the websocket dies, we dispose of it - //_logger.Debug(_localisationService.GetText("websocket-socket_lost_deleting_handle")); - // this is expected and relayed via "Player has disconnected" i dont think this is needed - lock (_lockObject) - { - if (_socketAliveTimers.TryGetValue(sessionID, out var timer)) - { - timer.Change(Timeout.Infinite, Timeout.Infinite); - _socketAliveTimers.Remove(sessionID); - } + foreach (var sptWebSocketMessageHandler in _messageHandlers) + { + await sptWebSocketMessageHandler.OnSptMessage(sessionID, ws, receivedMessage); + } + } - if (_sockets.ContainsKey(sessionID)) - { - _sockets.Remove(sessionID); - } + public async Task OnClose(WebSocket ws, HttpContext context) + { + var splitUrl = context.Request.Path.Value.Split("/"); + var sessionID = splitUrl.Last(); - if (_receiveTasks.TryGetValue(sessionID, out var receiveTask)) - { - receiveTask.CancelAsync().Wait(); - } - } - } - ); + lock (_lockObject) + { + _sockets.Remove(sessionID); + var playerProfile = _profileHelper.GetFullProfile(sessionID); + var playerInfoText = $"{playerProfile.ProfileInfo.Username} ({sessionID})"; + _logger.Info($"[ws] player: {playerInfoText} has disconnected"); + } + + await ws.CloseAsync(WebSocketCloseStatus.NormalClosure, "Client closed connection", CancellationToken.None); } public void SendMessage(string sessionID, WsNotificationEvent output) @@ -129,58 +120,6 @@ public class SptWebSocketConnectionHandler( return _sockets.TryGetValue(sessionID, out var socket) && socket.State == WebSocketState.Open; } - private void ReceiveTask(string sessionID, WebSocket ws, CancellationToken cancelToken) - { - List readBytes = new(); - while (ws.State == WebSocketState.Open) - { - try - { - if (cancelToken.IsCancellationRequested) - { - break; - } - - var isEndOfMessage = false; - while (!isEndOfMessage) - { - var buffer = new ArraySegment(new byte[1024 * 4]); - var readTask = ws.ReceiveAsync(buffer, cancelToken); - readTask.Wait(cancelToken); - readBytes.AddRange(buffer); - isEndOfMessage = readTask.Result.EndOfMessage; - } - - foreach (var sptWebSocketMessageHandler in _messageHandlers) - { - sptWebSocketMessageHandler.OnSptMessage(sessionID, ws, readBytes.ToArray()).Wait(); - } - } - catch (OperationCanceledException _) - { - _logger.Info("WebSocket disconnecting, receive task finalized..."); - } - catch (Exception _) - { - lock (_lockObject) - { - _sockets.Remove(sessionID); - _socketAliveTimers.Remove(sessionID); - _receiveTasks.Remove(sessionID); - var playerProfile = _profileHelper.GetFullProfile(sessionID); - var playerInfoText = $"{playerProfile.ProfileInfo.Username} ({sessionID})"; - _logger.Info($"[ws] player: {playerInfoText} has disconnected"); - } - - ws.CloseAsync(WebSocketCloseStatus.NormalClosure, "Client closed connection", CancellationToken.None); - } - finally - { - readBytes.Clear(); - } - } - } - public WebSocket GetSessionWebSocket(string sessionID) { return _sockets[sessionID];