Http server refactor (#593)

* Refactor various parts of the HttpListener to support Blazor loading

* Add logging for WebSocket requests

* Add better handling for WebSockets not belonging to SPT

* Remove unecessary check

* Remove check as it's already handled earlier now

* Cleanup

* Set delegate
This commit is contained in:
Jesse
2025-09-20 16:05:23 +02:00
committed by GitHub
parent 3d43e6479c
commit ef6e3b8c3a
8 changed files with 189 additions and 144 deletions
@@ -62,6 +62,8 @@
"chat-unable_to_register_command_already_registered": "Unable to register already registered command: %s",
"client_request": "[Client Request] %s",
"client_request_ip": "[Client Request] {{ip}} {{url}}",
"websocket_request": "[WebSocket Request] %s",
"websocket_request_ip": "[WebSocket Request] {{ip}} {{url}}",
"custom-quest-service_quest_id_already_exists": "A quest with the id: {{questId}} already exists.",
"custom-quest-service_no_languages_for_quest": "No languages have been added for custom quest id: {{questId}}",
"custom-quest-service_no_entries_for_language": "No locale entries have been added for language key: {{languageKey}}, was this intentional?",
@@ -6,5 +6,5 @@ namespace SPTarkov.Server.Core.Servers.Http;
public interface IHttpListener
{
bool CanHandle(MongoId sessionId, HttpRequest req);
Task Handle(MongoId sessionId, HttpRequest req, HttpResponse resp);
Task Handle(MongoId sessionId, RequestDelegate next, HttpContext context);
}
@@ -32,14 +32,21 @@ public class SptHttpListener(
return SupportedMethods.Contains(req.Method);
}
public async Task Handle(MongoId sessionId, HttpRequest req, HttpResponse resp)
public async Task Handle(MongoId sessionId, RequestDelegate next, HttpContext context)
{
switch (req.Method)
switch (context.Request.Method)
{
case "GET":
{
var response = await GetResponse(sessionId, req, null);
await SendResponse(sessionId, req, resp, null, response);
var response = await GetResponse(sessionId, next, context, null);
// Another handler is already handling this, or no handler was found.
if (response is null)
{
return;
}
await SendResponse(sessionId, context.Request, context.Response, null, response);
break;
}
// these are handled almost identically.
@@ -50,20 +57,21 @@ public class SptHttpListener(
// determine if the payload is compressed. All PUT requests are, and POST requests without
// debug = 1 are as well. This should be fixed.
// let compressed = req.headers["content-encoding"] === "deflate";
var requestIsCompressed = !req.Headers.TryGetValue("requestcompressed", out var compressHeader) || compressHeader != "0";
var requestCompressed = req.Method == "PUT" || requestIsCompressed;
var requestIsCompressed =
!context.Request.Headers.TryGetValue("requestcompressed", out var compressHeader) || compressHeader != "0";
var requestCompressed = context.Request.Method == "PUT" || requestIsCompressed;
string body;
if (requestCompressed)
{
await using var deflateStream = new ZLibStream(req.Body, CompressionMode.Decompress);
await using var deflateStream = new ZLibStream(context.Request.Body, CompressionMode.Decompress);
using var reader = new StreamReader(deflateStream, Encoding.UTF8);
body = await reader.ReadToEndAsync();
}
else
{
using var reader = new StreamReader(req.Body, Encoding.UTF8);
using var reader = new StreamReader(context.Request.Body, Encoding.UTF8);
body = await reader.ReadToEndAsync();
}
@@ -75,14 +83,15 @@ public class SptHttpListener(
}
}
var response = await GetResponse(sessionId, req, body);
await SendResponse(sessionId, req, resp, body, response);
break;
var response = await GetResponse(sessionId, next, context, body);
// Another handler is already handling this, or no handler was found.
if (response is null)
{
return;
}
default:
{
logger.Warning($"{serverLocalisationService.GetText("unknown_request")}: {req.Method}");
await SendResponse(sessionId, context.Request, context.Response, body, response);
break;
}
}
@@ -154,24 +163,24 @@ public class SptHttpListener(
}
}
public async ValueTask<string> GetResponse(MongoId sessionId, HttpRequest req, string? body)
public async ValueTask<string?> GetResponse(MongoId sessionId, RequestDelegate next, HttpContext context, string? body)
{
var output = await httpRouter.GetResponse(req, sessionId, body);
// Route doesn't exist or response is not properly set up
if (string.IsNullOrEmpty(output))
{
logger.Error(serverLocalisationService.GetText("unhandled_response", req.Path.ToString()));
output = httpResponseUtil.GetBody<object?>(null, BackendErrorCodes.HTTPNotFound, $"UNHANDLED RESPONSE: {req.Path.ToString()}");
}
var output = await httpRouter.GetResponse(context.Request, sessionId, body);
if (ProgramStatics.ENTRY_TYPE() != EntryType.RELEASE)
{
// Parse quest info into object
var log = new Request(req.Method, new RequestData(req.Path.ToString(), req.Headers));
var log = new Request(context.Request.Method, new RequestData(context.Request.Path.ToString(), context.Request.Headers));
requestsLogger.Info($"REQUEST={jsonUtil.Serialize(log)}");
}
// Route doesn't exist or response is not properly set up, continue to next handlers.
if (string.IsNullOrEmpty(output))
{
await next(context);
return null;
}
return output;
}
@@ -1,10 +1,7 @@
using System.Net;
using System.Net.Sockets;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http;
using SPTarkov.DI.Annotations;
using SPTarkov.Server.Core.Models.Common;
using SPTarkov.Server.Core.Models.Spt.Config;
using SPTarkov.Server.Core.Models.Utils;
using SPTarkov.Server.Core.Servers.Http;
using SPTarkov.Server.Core.Services;
@@ -12,8 +9,6 @@ namespace SPTarkov.Server.Core.Servers;
[Injectable(InjectionType.Singleton)]
public class HttpServer(
ISptLogger<HttpServer> logger,
ServerLocalisationService serverLocalisationService,
ConfigServer configServer,
WebSocketServer webSocketServer,
ProfileActivityService profileActivityService,
@@ -22,9 +17,9 @@ public class HttpServer(
{
protected readonly HttpConfig HttpConfig = configServer.GetConfig<HttpConfig>();
public async Task HandleRequest(HttpContext context)
public async Task HandleRequest(HttpContext context, RequestDelegate next)
{
if (context.WebSockets.IsWebSocketRequest)
if (context.WebSockets.IsWebSocketRequest && webSocketServer.CanHandle(context))
{
await webSocketServer.OnConnection(context);
return;
@@ -34,116 +29,24 @@ public class HttpServer(
var sessionId = context.Request.Cookies.TryGetValue("PHPSESSID", out var sessionIdString)
? new MongoId(sessionIdString)
: MongoId.Empty();
if (!string.IsNullOrEmpty(sessionIdString))
{
profileActivityService.SetActivityTimestamp(sessionId);
}
var realIp = context.Connection.RemoteIpAddress ?? IPAddress.Parse("127.0.0.1");
if (HttpConfig.LogRequests)
{
LogRequest(context, realIp, IsPrivateOrLocalAddress(realIp));
}
try
{
var listener = httpListeners.FirstOrDefault(listener => listener.CanHandle(sessionId, context.Request));
if (listener != null)
{
await listener.Handle(sessionId, context.Request, context.Response);
}
}
catch (Exception ex)
{
logger.Critical("Error handling request: " + context.Request.Path);
logger.Critical(ex.Message);
logger.Critical(ex.StackTrace);
#if DEBUG
throw; // added this so we can debug something.
#endif
}
// This http request would be passed through the SPT Router and handled by an ICallback
}
/// <summary>
/// Log request - handle differently if request is local
/// </summary>
/// <param name="context">HttpContext of request</param>
/// <param name="clientIp">Ip of requester</param>
/// <param name="isLocalRequest">Is this local request</param>
protected void LogRequest(HttpContext context, IPAddress clientIp, bool isLocalRequest)
{
if (isLocalRequest)
{
logger.Info(serverLocalisationService.GetText("client_request", context.Request.Path.Value));
await listener.Handle(sessionId, next, context);
}
else
{
logger.Info(serverLocalisationService.GetText("client_request_ip", new { ip = clientIp, url = context.Request.Path.Value }));
await next(context);
}
}
/// <summary>
/// Check against hardcoded values that determine it's from a local address
/// </summary>
/// <param name="remoteAddress"> Address to check </param>
/// <returns> True if its local </returns>
protected bool IsPrivateOrLocalAddress(IPAddress remoteAddress)
{
if (IPAddress.IsLoopback(remoteAddress))
{
return true;
}
if (remoteAddress.AddressFamily == AddressFamily.InterNetwork)
{
var bytes = remoteAddress.GetAddressBytes();
switch (bytes[0])
{
case 10:
return true; // 10.0.0.0/8 (private)
case 169:
return bytes[1] == 254; // 169.254.0.0/16 (APIPA/link-local)
case 172:
return bytes[1] >= 16 && bytes[1] <= 31; // 172.16.0.0/12 (private)
case 192:
return bytes[1] == 168; // 192.168.0.0/16 (private)
default:
return false;
}
}
if (remoteAddress.AddressFamily == AddressFamily.InterNetworkV6)
{
if (remoteAddress.IsIPv6LinkLocal)
{
return true;
}
}
return false;
}
protected Dictionary<string, string> GetCookies(HttpRequest req)
{
var found = new Dictionary<string, string>();
foreach (var keyValuePair in req.Cookies)
{
found.Add(keyValuePair.Key, keyValuePair.Value);
}
return found;
}
public string ListeningUrl()
{
return $"https://{HttpConfig.Ip}:{HttpConfig.Port}";
@@ -10,6 +10,11 @@ namespace SPTarkov.Server.Core.Servers;
[Injectable(InjectionType.Singleton)]
public class WebSocketServer(IEnumerable<IWebSocketConnectionHandler> webSocketConnectionHandler, ISptLogger<WebSocketServer> logger)
{
public bool CanHandle(HttpContext context)
{
return webSocketConnectionHandler.Any(wsh => context.Request.Path.Value.Contains(wsh.GetHookUrl()));
}
public async Task OnConnection(HttpContext httpContext)
{
var socket = await httpContext.WebSockets.AcceptWebSocketAsync();
@@ -22,16 +27,6 @@ public class WebSocketServer(IEnumerable<IWebSocketConnectionHandler> webSocketC
var cts = new CancellationTokenSource();
var wsToken = cts.Token;
if (!socketHandlers.Any())
{
var message =
$"Socket connection received for url {context.Request.Path.Value}, but there is no websocket handler configured for it!";
logger.Debug(message);
await webSocket.CloseAsync(WebSocketCloseStatus.ProtocolError, message, CancellationToken.None);
return;
}
var webSocketIdContext = DateTime.UtcNow.ToString("yyyyMMddHHmmssfff");
if (logger.IsLogEnabled(LogLevel.Debug))
@@ -20,8 +20,10 @@ public class StatusPage(TimeUtil timeUtil, ProfileActivityService profileActivit
return req.Method == "GET" && req.Path.Value.Contains("/status");
}
public async Task Handle(MongoId sessionId, HttpRequest req, HttpResponse resp)
public async Task Handle(MongoId sessionId, RequestDelegate next, HttpContext context)
{
var resp = context.Response;
var sptVersion = $"SPT version: {ProgramStatics.SPT_VERSION()}";
var debugEnabled = $"Debug enabled: {ProgramStatics.DEBUG()}";
var modsEnabled = $"Mods enabled: {ProgramStatics.MODS()}";
@@ -0,0 +1,131 @@
using System.Net;
using System.Net.Sockets;
using SPTarkov.Server.Core.Models.Spt.Config;
using SPTarkov.Server.Core.Models.Utils;
using SPTarkov.Server.Core.Servers;
using SPTarkov.Server.Core.Services;
namespace SPTarkov.Server.Logger;
public class SptLoggerMiddleware(
RequestDelegate next,
ServerLocalisationService serverLocalisationService,
ConfigServer configServer,
ISptLogger<SptLoggerMiddleware> logger
)
{
protected readonly HttpConfig HttpConfig = configServer.GetConfig<HttpConfig>();
public async Task InvokeAsync(HttpContext context)
{
if (!HttpConfig.LogRequests)
{
await next(context);
return;
}
var realIp = context.Connection.RemoteIpAddress ?? IPAddress.Parse("127.0.0.1");
LogRequest(context, realIp, IsPrivateOrLocalAddress(realIp), context.WebSockets.IsWebSocketRequest);
try
{
await next(context);
if (context.Response.StatusCode == 404)
{
logger.Error(serverLocalisationService.GetText("unhandled_response", context.Request.Path.ToString()));
}
}
catch (Exception ex)
{
logger.Critical("Error handling request: " + context.Request.Path);
logger.Critical(ex.Message);
logger.Critical(ex.StackTrace);
#if DEBUG
throw; // added this so we can debug something.
#endif
}
}
/// <summary>
/// Log request - handle differently if request is local
/// </summary>
/// <param name="context">HttpContext of request</param>
/// <param name="clientIp">Ip of requester</param>
/// <param name="isLocalRequest">Is this local request</param>
protected void LogRequest(HttpContext context, IPAddress clientIp, bool isLocalRequest, bool isWSRequest)
{
if (isWSRequest)
{
if (isLocalRequest)
{
logger.Info(serverLocalisationService.GetText("websocket_request", context.Request.Path.Value));
}
else
{
logger.Info(
serverLocalisationService.GetText("websocket_request_ip", new { ip = clientIp, url = context.Request.Path.Value })
);
}
}
else
{
if (isLocalRequest)
{
logger.Info(serverLocalisationService.GetText("client_request", context.Request.Path.Value));
}
else
{
logger.Info(
serverLocalisationService.GetText("client_request_ip", new { ip = clientIp, url = context.Request.Path.Value })
);
}
}
}
/// <summary>
/// Check against hardcoded values that determine it's from a local address
/// </summary>
/// <param name="remoteAddress"> Address to check </param>
/// <returns> True if its local </returns>
protected bool IsPrivateOrLocalAddress(IPAddress remoteAddress)
{
if (IPAddress.IsLoopback(remoteAddress))
{
return true;
}
if (remoteAddress.AddressFamily == AddressFamily.InterNetwork)
{
var bytes = remoteAddress.GetAddressBytes();
switch (bytes[0])
{
case 10:
return true; // 10.0.0.0/8 (private)
case 169:
return bytes[1] == 254; // 169.254.0.0/16 (APIPA/link-local)
case 172:
return bytes[1] >= 16 && bytes[1] <= 31; // 172.16.0.0/12 (private)
case 192:
return bytes[1] == 168; // 192.168.0.0/16 (private)
default:
return false;
}
}
if (remoteAddress.AddressFamily == AddressFamily.InterNetworkV6)
{
if (remoteAddress.IsIPv6LinkLocal)
{
return true;
}
}
return false;
}
}
+5 -2
View File
@@ -116,10 +116,13 @@ public static class Program
KeepAliveInterval = TimeSpan.FromSeconds(60),
}
);
app.UseMiddleware<SptLoggerMiddleware>();
app.Use(
async (HttpContext context, RequestDelegate _) =>
async (HttpContext context, RequestDelegate next) =>
{
await context.RequestServices.GetRequiredService<HttpServer>().HandleRequest(context);
await context.RequestServices.GetRequiredService<HttpServer>().HandleRequest(context, next);
}
);
}