Http server refactor (#593)
* Refactor various parts of the HttpListener to support Blazor loading * Add logging for WebSocket requests * Add better handling for WebSockets not belonging to SPT * Remove unecessary check * Remove check as it's already handled earlier now * Cleanup * Set delegate
This commit is contained in:
@@ -62,6 +62,8 @@
|
||||
"chat-unable_to_register_command_already_registered": "Unable to register already registered command: %s",
|
||||
"client_request": "[Client Request] %s",
|
||||
"client_request_ip": "[Client Request] {{ip}} {{url}}",
|
||||
"websocket_request": "[WebSocket Request] %s",
|
||||
"websocket_request_ip": "[WebSocket Request] {{ip}} {{url}}",
|
||||
"custom-quest-service_quest_id_already_exists": "A quest with the id: {{questId}} already exists.",
|
||||
"custom-quest-service_no_languages_for_quest": "No languages have been added for custom quest id: {{questId}}",
|
||||
"custom-quest-service_no_entries_for_language": "No locale entries have been added for language key: {{languageKey}}, was this intentional?",
|
||||
|
||||
@@ -6,5 +6,5 @@ namespace SPTarkov.Server.Core.Servers.Http;
|
||||
public interface IHttpListener
|
||||
{
|
||||
bool CanHandle(MongoId sessionId, HttpRequest req);
|
||||
Task Handle(MongoId sessionId, HttpRequest req, HttpResponse resp);
|
||||
Task Handle(MongoId sessionId, RequestDelegate next, HttpContext context);
|
||||
}
|
||||
|
||||
@@ -32,14 +32,21 @@ public class SptHttpListener(
|
||||
return SupportedMethods.Contains(req.Method);
|
||||
}
|
||||
|
||||
public async Task Handle(MongoId sessionId, HttpRequest req, HttpResponse resp)
|
||||
public async Task Handle(MongoId sessionId, RequestDelegate next, HttpContext context)
|
||||
{
|
||||
switch (req.Method)
|
||||
switch (context.Request.Method)
|
||||
{
|
||||
case "GET":
|
||||
{
|
||||
var response = await GetResponse(sessionId, req, null);
|
||||
await SendResponse(sessionId, req, resp, null, response);
|
||||
var response = await GetResponse(sessionId, next, context, null);
|
||||
|
||||
// Another handler is already handling this, or no handler was found.
|
||||
if (response is null)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
await SendResponse(sessionId, context.Request, context.Response, null, response);
|
||||
break;
|
||||
}
|
||||
// these are handled almost identically.
|
||||
@@ -50,20 +57,21 @@ public class SptHttpListener(
|
||||
// determine if the payload is compressed. All PUT requests are, and POST requests without
|
||||
// debug = 1 are as well. This should be fixed.
|
||||
// let compressed = req.headers["content-encoding"] === "deflate";
|
||||
var requestIsCompressed = !req.Headers.TryGetValue("requestcompressed", out var compressHeader) || compressHeader != "0";
|
||||
var requestCompressed = req.Method == "PUT" || requestIsCompressed;
|
||||
var requestIsCompressed =
|
||||
!context.Request.Headers.TryGetValue("requestcompressed", out var compressHeader) || compressHeader != "0";
|
||||
var requestCompressed = context.Request.Method == "PUT" || requestIsCompressed;
|
||||
|
||||
string body;
|
||||
|
||||
if (requestCompressed)
|
||||
{
|
||||
await using var deflateStream = new ZLibStream(req.Body, CompressionMode.Decompress);
|
||||
await using var deflateStream = new ZLibStream(context.Request.Body, CompressionMode.Decompress);
|
||||
using var reader = new StreamReader(deflateStream, Encoding.UTF8);
|
||||
body = await reader.ReadToEndAsync();
|
||||
}
|
||||
else
|
||||
{
|
||||
using var reader = new StreamReader(req.Body, Encoding.UTF8);
|
||||
using var reader = new StreamReader(context.Request.Body, Encoding.UTF8);
|
||||
body = await reader.ReadToEndAsync();
|
||||
}
|
||||
|
||||
@@ -75,14 +83,15 @@ public class SptHttpListener(
|
||||
}
|
||||
}
|
||||
|
||||
var response = await GetResponse(sessionId, req, body);
|
||||
await SendResponse(sessionId, req, resp, body, response);
|
||||
break;
|
||||
var response = await GetResponse(sessionId, next, context, body);
|
||||
|
||||
// Another handler is already handling this, or no handler was found.
|
||||
if (response is null)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
default:
|
||||
{
|
||||
logger.Warning($"{serverLocalisationService.GetText("unknown_request")}: {req.Method}");
|
||||
await SendResponse(sessionId, context.Request, context.Response, body, response);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -154,24 +163,24 @@ public class SptHttpListener(
|
||||
}
|
||||
}
|
||||
|
||||
public async ValueTask<string> GetResponse(MongoId sessionId, HttpRequest req, string? body)
|
||||
public async ValueTask<string?> GetResponse(MongoId sessionId, RequestDelegate next, HttpContext context, string? body)
|
||||
{
|
||||
var output = await httpRouter.GetResponse(req, sessionId, body);
|
||||
|
||||
// Route doesn't exist or response is not properly set up
|
||||
if (string.IsNullOrEmpty(output))
|
||||
{
|
||||
logger.Error(serverLocalisationService.GetText("unhandled_response", req.Path.ToString()));
|
||||
output = httpResponseUtil.GetBody<object?>(null, BackendErrorCodes.HTTPNotFound, $"UNHANDLED RESPONSE: {req.Path.ToString()}");
|
||||
}
|
||||
var output = await httpRouter.GetResponse(context.Request, sessionId, body);
|
||||
|
||||
if (ProgramStatics.ENTRY_TYPE() != EntryType.RELEASE)
|
||||
{
|
||||
// Parse quest info into object
|
||||
var log = new Request(req.Method, new RequestData(req.Path.ToString(), req.Headers));
|
||||
var log = new Request(context.Request.Method, new RequestData(context.Request.Path.ToString(), context.Request.Headers));
|
||||
requestsLogger.Info($"REQUEST={jsonUtil.Serialize(log)}");
|
||||
}
|
||||
|
||||
// Route doesn't exist or response is not properly set up, continue to next handlers.
|
||||
if (string.IsNullOrEmpty(output))
|
||||
{
|
||||
await next(context);
|
||||
return null;
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
using System.Net;
|
||||
using System.Net.Sockets;
|
||||
using Microsoft.AspNetCore.Http;
|
||||
using Microsoft.AspNetCore.Http;
|
||||
using SPTarkov.DI.Annotations;
|
||||
using SPTarkov.Server.Core.Models.Common;
|
||||
using SPTarkov.Server.Core.Models.Spt.Config;
|
||||
using SPTarkov.Server.Core.Models.Utils;
|
||||
using SPTarkov.Server.Core.Servers.Http;
|
||||
using SPTarkov.Server.Core.Services;
|
||||
|
||||
@@ -12,8 +9,6 @@ namespace SPTarkov.Server.Core.Servers;
|
||||
|
||||
[Injectable(InjectionType.Singleton)]
|
||||
public class HttpServer(
|
||||
ISptLogger<HttpServer> logger,
|
||||
ServerLocalisationService serverLocalisationService,
|
||||
ConfigServer configServer,
|
||||
WebSocketServer webSocketServer,
|
||||
ProfileActivityService profileActivityService,
|
||||
@@ -22,9 +17,9 @@ public class HttpServer(
|
||||
{
|
||||
protected readonly HttpConfig HttpConfig = configServer.GetConfig<HttpConfig>();
|
||||
|
||||
public async Task HandleRequest(HttpContext context)
|
||||
public async Task HandleRequest(HttpContext context, RequestDelegate next)
|
||||
{
|
||||
if (context.WebSockets.IsWebSocketRequest)
|
||||
if (context.WebSockets.IsWebSocketRequest && webSocketServer.CanHandle(context))
|
||||
{
|
||||
await webSocketServer.OnConnection(context);
|
||||
return;
|
||||
@@ -34,116 +29,24 @@ public class HttpServer(
|
||||
var sessionId = context.Request.Cookies.TryGetValue("PHPSESSID", out var sessionIdString)
|
||||
? new MongoId(sessionIdString)
|
||||
: MongoId.Empty();
|
||||
|
||||
if (!string.IsNullOrEmpty(sessionIdString))
|
||||
{
|
||||
profileActivityService.SetActivityTimestamp(sessionId);
|
||||
}
|
||||
|
||||
var realIp = context.Connection.RemoteIpAddress ?? IPAddress.Parse("127.0.0.1");
|
||||
|
||||
if (HttpConfig.LogRequests)
|
||||
{
|
||||
LogRequest(context, realIp, IsPrivateOrLocalAddress(realIp));
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
var listener = httpListeners.FirstOrDefault(listener => listener.CanHandle(sessionId, context.Request));
|
||||
|
||||
if (listener != null)
|
||||
{
|
||||
await listener.Handle(sessionId, context.Request, context.Response);
|
||||
}
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
logger.Critical("Error handling request: " + context.Request.Path);
|
||||
logger.Critical(ex.Message);
|
||||
logger.Critical(ex.StackTrace);
|
||||
#if DEBUG
|
||||
throw; // added this so we can debug something.
|
||||
#endif
|
||||
}
|
||||
|
||||
// This http request would be passed through the SPT Router and handled by an ICallback
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Log request - handle differently if request is local
|
||||
/// </summary>
|
||||
/// <param name="context">HttpContext of request</param>
|
||||
/// <param name="clientIp">Ip of requester</param>
|
||||
/// <param name="isLocalRequest">Is this local request</param>
|
||||
protected void LogRequest(HttpContext context, IPAddress clientIp, bool isLocalRequest)
|
||||
{
|
||||
if (isLocalRequest)
|
||||
{
|
||||
logger.Info(serverLocalisationService.GetText("client_request", context.Request.Path.Value));
|
||||
await listener.Handle(sessionId, next, context);
|
||||
}
|
||||
else
|
||||
{
|
||||
logger.Info(serverLocalisationService.GetText("client_request_ip", new { ip = clientIp, url = context.Request.Path.Value }));
|
||||
await next(context);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Check against hardcoded values that determine it's from a local address
|
||||
/// </summary>
|
||||
/// <param name="remoteAddress"> Address to check </param>
|
||||
/// <returns> True if its local </returns>
|
||||
protected bool IsPrivateOrLocalAddress(IPAddress remoteAddress)
|
||||
{
|
||||
if (IPAddress.IsLoopback(remoteAddress))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
if (remoteAddress.AddressFamily == AddressFamily.InterNetwork)
|
||||
{
|
||||
var bytes = remoteAddress.GetAddressBytes();
|
||||
|
||||
switch (bytes[0])
|
||||
{
|
||||
case 10:
|
||||
return true; // 10.0.0.0/8 (private)
|
||||
|
||||
case 169:
|
||||
return bytes[1] == 254; // 169.254.0.0/16 (APIPA/link-local)
|
||||
|
||||
case 172:
|
||||
return bytes[1] >= 16 && bytes[1] <= 31; // 172.16.0.0/12 (private)
|
||||
|
||||
case 192:
|
||||
return bytes[1] == 168; // 192.168.0.0/16 (private)
|
||||
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (remoteAddress.AddressFamily == AddressFamily.InterNetworkV6)
|
||||
{
|
||||
if (remoteAddress.IsIPv6LinkLocal)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
protected Dictionary<string, string> GetCookies(HttpRequest req)
|
||||
{
|
||||
var found = new Dictionary<string, string>();
|
||||
|
||||
foreach (var keyValuePair in req.Cookies)
|
||||
{
|
||||
found.Add(keyValuePair.Key, keyValuePair.Value);
|
||||
}
|
||||
|
||||
return found;
|
||||
}
|
||||
|
||||
public string ListeningUrl()
|
||||
{
|
||||
return $"https://{HttpConfig.Ip}:{HttpConfig.Port}";
|
||||
|
||||
@@ -10,6 +10,11 @@ namespace SPTarkov.Server.Core.Servers;
|
||||
[Injectable(InjectionType.Singleton)]
|
||||
public class WebSocketServer(IEnumerable<IWebSocketConnectionHandler> webSocketConnectionHandler, ISptLogger<WebSocketServer> logger)
|
||||
{
|
||||
public bool CanHandle(HttpContext context)
|
||||
{
|
||||
return webSocketConnectionHandler.Any(wsh => context.Request.Path.Value.Contains(wsh.GetHookUrl()));
|
||||
}
|
||||
|
||||
public async Task OnConnection(HttpContext httpContext)
|
||||
{
|
||||
var socket = await httpContext.WebSockets.AcceptWebSocketAsync();
|
||||
@@ -22,16 +27,6 @@ public class WebSocketServer(IEnumerable<IWebSocketConnectionHandler> webSocketC
|
||||
|
||||
var cts = new CancellationTokenSource();
|
||||
var wsToken = cts.Token;
|
||||
|
||||
if (!socketHandlers.Any())
|
||||
{
|
||||
var message =
|
||||
$"Socket connection received for url {context.Request.Path.Value}, but there is no websocket handler configured for it!";
|
||||
logger.Debug(message);
|
||||
await webSocket.CloseAsync(WebSocketCloseStatus.ProtocolError, message, CancellationToken.None);
|
||||
return;
|
||||
}
|
||||
|
||||
var webSocketIdContext = DateTime.UtcNow.ToString("yyyyMMddHHmmssfff");
|
||||
|
||||
if (logger.IsLogEnabled(LogLevel.Debug))
|
||||
|
||||
@@ -20,8 +20,10 @@ public class StatusPage(TimeUtil timeUtil, ProfileActivityService profileActivit
|
||||
return req.Method == "GET" && req.Path.Value.Contains("/status");
|
||||
}
|
||||
|
||||
public async Task Handle(MongoId sessionId, HttpRequest req, HttpResponse resp)
|
||||
public async Task Handle(MongoId sessionId, RequestDelegate next, HttpContext context)
|
||||
{
|
||||
var resp = context.Response;
|
||||
|
||||
var sptVersion = $"SPT version: {ProgramStatics.SPT_VERSION()}";
|
||||
var debugEnabled = $"Debug enabled: {ProgramStatics.DEBUG()}";
|
||||
var modsEnabled = $"Mods enabled: {ProgramStatics.MODS()}";
|
||||
|
||||
@@ -0,0 +1,131 @@
|
||||
using System.Net;
|
||||
using System.Net.Sockets;
|
||||
using SPTarkov.Server.Core.Models.Spt.Config;
|
||||
using SPTarkov.Server.Core.Models.Utils;
|
||||
using SPTarkov.Server.Core.Servers;
|
||||
using SPTarkov.Server.Core.Services;
|
||||
|
||||
namespace SPTarkov.Server.Logger;
|
||||
|
||||
public class SptLoggerMiddleware(
|
||||
RequestDelegate next,
|
||||
ServerLocalisationService serverLocalisationService,
|
||||
ConfigServer configServer,
|
||||
ISptLogger<SptLoggerMiddleware> logger
|
||||
)
|
||||
{
|
||||
protected readonly HttpConfig HttpConfig = configServer.GetConfig<HttpConfig>();
|
||||
|
||||
public async Task InvokeAsync(HttpContext context)
|
||||
{
|
||||
if (!HttpConfig.LogRequests)
|
||||
{
|
||||
await next(context);
|
||||
return;
|
||||
}
|
||||
|
||||
var realIp = context.Connection.RemoteIpAddress ?? IPAddress.Parse("127.0.0.1");
|
||||
LogRequest(context, realIp, IsPrivateOrLocalAddress(realIp), context.WebSockets.IsWebSocketRequest);
|
||||
|
||||
try
|
||||
{
|
||||
await next(context);
|
||||
|
||||
if (context.Response.StatusCode == 404)
|
||||
{
|
||||
logger.Error(serverLocalisationService.GetText("unhandled_response", context.Request.Path.ToString()));
|
||||
}
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
logger.Critical("Error handling request: " + context.Request.Path);
|
||||
logger.Critical(ex.Message);
|
||||
logger.Critical(ex.StackTrace);
|
||||
#if DEBUG
|
||||
throw; // added this so we can debug something.
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Log request - handle differently if request is local
|
||||
/// </summary>
|
||||
/// <param name="context">HttpContext of request</param>
|
||||
/// <param name="clientIp">Ip of requester</param>
|
||||
/// <param name="isLocalRequest">Is this local request</param>
|
||||
protected void LogRequest(HttpContext context, IPAddress clientIp, bool isLocalRequest, bool isWSRequest)
|
||||
{
|
||||
if (isWSRequest)
|
||||
{
|
||||
if (isLocalRequest)
|
||||
{
|
||||
logger.Info(serverLocalisationService.GetText("websocket_request", context.Request.Path.Value));
|
||||
}
|
||||
else
|
||||
{
|
||||
logger.Info(
|
||||
serverLocalisationService.GetText("websocket_request_ip", new { ip = clientIp, url = context.Request.Path.Value })
|
||||
);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if (isLocalRequest)
|
||||
{
|
||||
logger.Info(serverLocalisationService.GetText("client_request", context.Request.Path.Value));
|
||||
}
|
||||
else
|
||||
{
|
||||
logger.Info(
|
||||
serverLocalisationService.GetText("client_request_ip", new { ip = clientIp, url = context.Request.Path.Value })
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Check against hardcoded values that determine it's from a local address
|
||||
/// </summary>
|
||||
/// <param name="remoteAddress"> Address to check </param>
|
||||
/// <returns> True if its local </returns>
|
||||
protected bool IsPrivateOrLocalAddress(IPAddress remoteAddress)
|
||||
{
|
||||
if (IPAddress.IsLoopback(remoteAddress))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
if (remoteAddress.AddressFamily == AddressFamily.InterNetwork)
|
||||
{
|
||||
var bytes = remoteAddress.GetAddressBytes();
|
||||
|
||||
switch (bytes[0])
|
||||
{
|
||||
case 10:
|
||||
return true; // 10.0.0.0/8 (private)
|
||||
|
||||
case 169:
|
||||
return bytes[1] == 254; // 169.254.0.0/16 (APIPA/link-local)
|
||||
|
||||
case 172:
|
||||
return bytes[1] >= 16 && bytes[1] <= 31; // 172.16.0.0/12 (private)
|
||||
|
||||
case 192:
|
||||
return bytes[1] == 168; // 192.168.0.0/16 (private)
|
||||
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (remoteAddress.AddressFamily == AddressFamily.InterNetworkV6)
|
||||
{
|
||||
if (remoteAddress.IsIPv6LinkLocal)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -116,10 +116,13 @@ public static class Program
|
||||
KeepAliveInterval = TimeSpan.FromSeconds(60),
|
||||
}
|
||||
);
|
||||
|
||||
app.UseMiddleware<SptLoggerMiddleware>();
|
||||
|
||||
app.Use(
|
||||
async (HttpContext context, RequestDelegate _) =>
|
||||
async (HttpContext context, RequestDelegate next) =>
|
||||
{
|
||||
await context.RequestServices.GetRequiredService<HttpServer>().HandleRequest(context);
|
||||
await context.RequestServices.GetRequiredService<HttpServer>().HandleRequest(context, next);
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user