Refactor websockets to be easier for users to work with

- will test later today
- example will need updating
This commit is contained in:
Archangel
2025-02-17 12:02:20 +01:00
parent 5885e141e4
commit 6d7cdf1f3b
5 changed files with 86 additions and 116 deletions
@@ -11,7 +11,7 @@ namespace Core.Helpers;
[Injectable]
public class NotificationSendHelper(
IWebSocketConnectionHandler _sptWebSocketConnectionHandler,
SptWebSocketConnectionHandler _sptWebSocketConnectionHandler,
HashUtil _hashUtil,
SaveServer _saveServer,
NotificationService _notificationService,
+9 -6
View File
@@ -29,12 +29,16 @@ public class HttpServer(
var app = builder?.Build();
// enable web socket
app?.UseWebSockets();
app?.UseWebSockets(new WebSocketOptions
{
// Every minute a heartbeat is sent to keep the connection alive.
KeepAliveInterval = TimeSpan.FromSeconds(60)
});
app?.Use(
(HttpContext req, RequestDelegate _) =>
{
return Task.Factory.StartNew(() => HandleFallback(req));
return Task.Factory.StartNew(async () => await HandleFallback(req));
}
);
started = true;
@@ -46,11 +50,12 @@ public class HttpServer(
_applicationContext.AddValue(ContextVariableType.WEB_APPLICATION, app);
}
private Task HandleFallback(HttpContext context)
private async Task HandleFallback(HttpContext context)
{
if (context.WebSockets.IsWebSocketRequest)
{
return _webSocketServer.OnConnection(context);
await _webSocketServer.OnConnection(context);
return;
}
context.Request.Cookies.TryGetValue("PHPSESSID", out var sessionId);
@@ -103,8 +108,6 @@ public class HttpServer(
_httpListeners.SingleOrDefault(l => l.CanHandle(sessionId, context.Request))?.Handle(sessionId, context.Request, context.Response);
// This http request would be passed through the SPT Router and handled by an ICallback
return Task.CompletedTask;
}
private bool? IsLocalRequest(string? remoteAddress)
+39 -9
View File
@@ -1,7 +1,6 @@
using System.Net.WebSockets;
using Core.Models.Utils;
using Core.Servers.Ws;
using Core.Utils;
using SptCommon.Annotations;
namespace Core.Servers;
@@ -9,37 +8,68 @@ namespace Core.Servers;
[Injectable(InjectionType.Singleton)]
public class WebSocketServer(
IEnumerable<IWebSocketConnectionHandler> _webSocketConnectionHandler,
ISptLogger<WebSocketServer> _logger,
JsonUtil _jsonUtil
ISptLogger<WebSocketServer> _logger
)
{
public async Task OnConnection(HttpContext httpContext)
{
var socket = await httpContext.WebSockets.AcceptWebSocketAsync();
await HandleCommunication(httpContext, socket);
await HandleWebSocket(httpContext, socket);
}
private Task HandleCommunication(HttpContext context, WebSocket webSocket)
private async Task HandleWebSocket(HttpContext context, WebSocket webSocket)
{
var socketHandlers = _webSocketConnectionHandler
.Where(wsh => context.Request.Path.Value.Contains(wsh.GetHookUrl()))
.ToList();
if (socketHandlers.Count == 0)
{
var message = $"Socket connection received for url {context.Request.Path.Value}, but there is not websocket handler configured for it";
_logger.Warning(message);
return webSocket.CloseAsync(WebSocketCloseStatus.ProtocolError, message, CancellationToken.None);
await webSocket.CloseAsync(WebSocketCloseStatus.ProtocolError, message, CancellationToken.None);
return;
}
foreach (var wsh in socketHandlers)
{
wsh.OnConnection(webSocket, context).Wait();
if (webSocket.State == WebSocketState.Open)
{
_logger.Info($"WebSocketHandler \"{wsh.GetSocketId()}\" connected");
}
await wsh.OnConnection(webSocket, context);
}
return Task.CompletedTask;
var messageBuffer = new byte[1024];
try
{
while (webSocket.State == WebSocketState.Open)
{
var receiveResult = await webSocket.ReceiveAsync(new ArraySegment<byte>(messageBuffer), CancellationToken.None);
if (receiveResult.MessageType == WebSocketMessageType.Text || receiveResult.MessageType == WebSocketMessageType.Binary)
{
foreach (var wsh in socketHandlers)
{
await wsh.OnMessage(messageBuffer.ToArray(), receiveResult.MessageType, webSocket, context);
}
}
else if (receiveResult.MessageType == WebSocketMessageType.Close)
{
foreach (var wsh in socketHandlers)
{
await wsh.OnClose(webSocket, context);
}
}
}
}
catch (Exception)
{
foreach (var wsh in socketHandlers)
{
await wsh.OnClose(webSocket, context);
}
}
}
}
@@ -1,5 +1,4 @@
using System.Net.WebSockets;
using Core.Models.Eft.Ws;
namespace Core.Servers.Ws;
@@ -8,7 +7,6 @@ public interface IWebSocketConnectionHandler
string GetHookUrl();
string GetSocketId();
Task OnConnection(WebSocket ws, HttpContext context);
bool IsWebSocketConnected(string sessionId);
void SendMessage(string sessionID, WsNotificationEvent output);
Task OnMessage(byte[] rawData, WebSocketMessageType messageType, WebSocket ws, HttpContext context);
Task OnClose(WebSocket ws, HttpContext context);
}
@@ -2,7 +2,6 @@ using System.Net.WebSockets;
using System.Text;
using Core.Helpers;
using Core.Models.Eft.Ws;
using Core.Models.Spt.Config;
using Core.Models.Utils;
using Core.Servers.Ws.Message;
using Core.Services;
@@ -24,9 +23,6 @@ public class SptWebSocketConnectionHandler(
{
protected WsPing _defaultNotification = new();
protected Lock _lockObject = new();
protected Dictionary<string, CancellationTokenSource> _receiveTasks = new();
protected Dictionary<string, Timer> _socketAliveTimers = new();
protected Dictionary<string, WebSocket> _sockets = new();
public string GetHookUrl()
@@ -41,53 +37,48 @@ public class SptWebSocketConnectionHandler(
public Task OnConnection(WebSocket ws, HttpContext context)
{
return Task.Factory.StartNew(
() =>
{
var splitUrl = context.Request.Path.Value.Split("/");
var sessionID = splitUrl.Last();
var playerProfile = _profileHelper.GetFullProfile(sessionID);
var playerInfoText = $"{playerProfile.ProfileInfo.Username} ({sessionID})";
var splitUrl = context.Request.Path.Value.Split("/");
var sessionID = splitUrl.Last();
var playerProfile = _profileHelper.GetFullProfile(sessionID);
var playerInfoText = $"{playerProfile.ProfileInfo.Username} ({sessionID})";
_logger.Info(_localisationService.GetText("websocket-player_connected", playerInfoText));
_logger.Info(_localisationService.GetText("websocket-player_connected", playerInfoText));
_sockets.Add(sessionID, ws);
lock (_lockObject)
{
_sockets.Add(sessionID, ws);
}
lock (_lockObject)
{
_receiveTasks.Add(sessionID, new CancellationTokenSource());
var cancelToken = _receiveTasks[sessionID].Token;
Task.Factory.StartNew(_ => ReceiveTask(sessionID, ws, cancelToken), null, cancelToken);
}
return Task.CompletedTask;
}
while (ws.State == WebSocketState.Open)
{
Thread.Sleep(1000);
}
public async Task OnMessage(byte[] receivedMessage, WebSocketMessageType messageType, WebSocket ws, HttpContext context)
{
var splitUrl = context.Request.Path.Value.Split("/");
var sessionID = splitUrl.Last();
var playerProfile = _profileHelper.GetFullProfile(sessionID);
var playerInfoText = $"{playerProfile.ProfileInfo.Username} ({sessionID})";
// Once the websocket dies, we dispose of it
//_logger.Debug(_localisationService.GetText("websocket-socket_lost_deleting_handle"));
// this is expected and relayed via "Player has disconnected" i dont think this is needed
lock (_lockObject)
{
if (_socketAliveTimers.TryGetValue(sessionID, out var timer))
{
timer.Change(Timeout.Infinite, Timeout.Infinite);
_socketAliveTimers.Remove(sessionID);
}
foreach (var sptWebSocketMessageHandler in _messageHandlers)
{
await sptWebSocketMessageHandler.OnSptMessage(sessionID, ws, receivedMessage);
}
}
if (_sockets.ContainsKey(sessionID))
{
_sockets.Remove(sessionID);
}
public async Task OnClose(WebSocket ws, HttpContext context)
{
var splitUrl = context.Request.Path.Value.Split("/");
var sessionID = splitUrl.Last();
if (_receiveTasks.TryGetValue(sessionID, out var receiveTask))
{
receiveTask.CancelAsync().Wait();
}
}
}
);
lock (_lockObject)
{
_sockets.Remove(sessionID);
var playerProfile = _profileHelper.GetFullProfile(sessionID);
var playerInfoText = $"{playerProfile.ProfileInfo.Username} ({sessionID})";
_logger.Info($"[ws] player: {playerInfoText} has disconnected");
}
await ws.CloseAsync(WebSocketCloseStatus.NormalClosure, "Client closed connection", CancellationToken.None);
}
public void SendMessage(string sessionID, WsNotificationEvent output)
@@ -129,58 +120,6 @@ public class SptWebSocketConnectionHandler(
return _sockets.TryGetValue(sessionID, out var socket) && socket.State == WebSocketState.Open;
}
private void ReceiveTask(string sessionID, WebSocket ws, CancellationToken cancelToken)
{
List<byte> readBytes = new();
while (ws.State == WebSocketState.Open)
{
try
{
if (cancelToken.IsCancellationRequested)
{
break;
}
var isEndOfMessage = false;
while (!isEndOfMessage)
{
var buffer = new ArraySegment<byte>(new byte[1024 * 4]);
var readTask = ws.ReceiveAsync(buffer, cancelToken);
readTask.Wait(cancelToken);
readBytes.AddRange(buffer);
isEndOfMessage = readTask.Result.EndOfMessage;
}
foreach (var sptWebSocketMessageHandler in _messageHandlers)
{
sptWebSocketMessageHandler.OnSptMessage(sessionID, ws, readBytes.ToArray()).Wait();
}
}
catch (OperationCanceledException _)
{
_logger.Info("WebSocket disconnecting, receive task finalized...");
}
catch (Exception _)
{
lock (_lockObject)
{
_sockets.Remove(sessionID);
_socketAliveTimers.Remove(sessionID);
_receiveTasks.Remove(sessionID);
var playerProfile = _profileHelper.GetFullProfile(sessionID);
var playerInfoText = $"{playerProfile.ProfileInfo.Username} ({sessionID})";
_logger.Info($"[ws] player: {playerInfoText} has disconnected");
}
ws.CloseAsync(WebSocketCloseStatus.NormalClosure, "Client closed connection", CancellationToken.None);
}
finally
{
readBytes.Clear();
}
}
}
public WebSocket GetSessionWebSocket(string sessionID)
{
return _sockets[sessionID];