DOS patch

This commit is contained in:
Tayrtahn
2026-05-01 16:24:15 -04:00
parent 0e1ed2e86d
commit 598a4ab29a
4 changed files with 248 additions and 17 deletions
@@ -0,0 +1,178 @@
using Lidgren.Network;
using NUnit.Framework;
using Robust.Shared.Network;
namespace Robust.Shared.Tests.Networking;
public sealed class NetEncryptionDoSTest
{
private const ulong Magic = 0x13377777_77777777;
[Test]
[Description("A control test that ensures connecting in a test works.")]
public void ConnectionWorks()
{
var (client, server) = MakeConnectionPair();
var message = client.CreateMessage();
message.WriteVariableUInt64(Magic);
client.SendMessage(message, NetDeliveryMethod.ReliableOrdered);
var packet = Receive(server);
Assert.That(packet.ReadVariableUInt64(), Is.EqualTo(Magic));
server.Shutdown(null);
}
[Test]
[Description("A control test that just ensures encryption works as other tests expect.")]
public void EncryptionWorks()
{
var (clientEnc, serverEnc) = MakeEncryptionPair();
var (client, server) = MakeConnectionPair();
var message = client.CreateMessage();
message.WriteVariableUInt64(Magic);
clientEnc.Encrypt(message);
client.SendMessage(message, NetDeliveryMethod.ReliableOrdered);
var packet = Receive(server);
Assert.That(serverEnc.TryDecrypt(packet), Is.True);
server.Shutdown(null);
}
[Test]
[Description("Attempt to decrypt a packet that is using the wrong encryption keys, ensuring it doesn't throw.")]
public void WrongKeyFailureDoesNotThrow()
{
var (clientEnc, serverEnc) = MakeEncryptionPair(disjointKey: true);
var (client, server) = MakeConnectionPair();
var message = client.CreateMessage();
message.WriteVariableUInt64(Magic);
clientEnc.Encrypt(message);
client.SendMessage(message, NetDeliveryMethod.ReliableOrdered);
var packet = Receive(server);
Assert.That(serverEnc.TryDecrypt(packet), Is.False);
server.Shutdown(null);
}
private static int[] _badMessages =
[
5,
1,
4,
16,
1024,
];
[Test]
[Description("Attempt to decrypt a packet that is bogus, ensuring it doesn't throw.")]
[TestCaseSource(nameof(_badMessages))]
public void BadMessageDoesNotThrow(int badMessageLength)
{
var badMessage = new byte[badMessageLength];
System.Random.Shared.NextBytes(badMessage);
var (_, serverEnc) = MakeEncryptionPair(disjointKey: true);
var (client, server) = MakeConnectionPair();
var message = client.CreateMessage();
message.Write(badMessage);
// Don't encrypt at all.
client.SendMessage(message, NetDeliveryMethod.ReliableOrdered);
var packet = Receive(server);
Assert.That(packet.LengthBytes, Is.EqualTo(badMessageLength));
Assert.That(serverEnc.TryDecrypt(packet), Is.False);
server.Shutdown(null);
}
// TODO: Generalize all this for other low level network tests.
private (NetClient client, NetServer server) MakeConnectionPair()
{
const string id = "test";
var client = new NetClient(new NetPeerConfiguration(id));
var server = new NetServer( new NetPeerConfiguration(id));
client.Start();
// Lidgren has no facilities for mocking this nicely.
// So we just use an actual socket.
server.Start();
client.Connect("localhost", server.Port);
var ready = false;
while (!ready)
{
switch (server.WaitMessage(1000))
{
case { MessageType: NetIncomingMessageType.StatusChanged } msg:
{
// hello there.
var status = (NetConnectionStatus)msg.ReadByte();
if (status == NetConnectionStatus.Connected)
ready = true;
break;
}
}
}
return (client, server);
}
private NetIncomingMessage Receive(NetPeer peer)
{
NetIncomingMessage? found = null;
while (found == null)
{
switch (peer.WaitMessage(1000))
{
case { MessageType: NetIncomingMessageType.Data } msg:
{
found = msg;
break;
}
}
}
return found;
}
private (NetEncryption client, NetEncryption server) MakeEncryptionPair(bool disjointKey = false)
{
var serverKey = new byte[32];
System.Random.Shared.NextBytes(serverKey.AsSpan());
var clientKey = (byte[])serverKey.Clone();
if (disjointKey)
System.Random.Shared.NextBytes(clientKey.AsSpan());
return (new NetEncryption(clientKey, false), new NetEncryption(serverKey, true));
}
}
+12 -3
View File
@@ -84,8 +84,18 @@ internal sealed class NetEncryption
ArrayPool<byte>.Shared.Return(returnPool);
}
public unsafe void Decrypt(NetIncomingMessage message)
/// <summary>
/// Attempts to decrypt an incoming network message, falliably.
/// </summary>
/// <param name="message">The message to decrypt in-place. This will be mutated with the decrypted results.</param>
/// <returns>Whether the operation was successful. If this fails, you likely want to drop the connection.</returns>
public unsafe bool TryDecrypt(NetIncomingMessage message)
{
// Minimum possible size a message can be is the nonce + 16 bytes of message.
// So we immediately bail on anything smaller.
if (message.LengthBytes < sizeof(ulong) + CryptoAeadXChaCha20Poly1305Ietf.AddBytes)
return false;
var nonce = message.ReadUInt64();
var cipherText = message.Data.AsSpan(sizeof(ulong), message.LengthBytes - sizeof(ulong));
@@ -114,7 +124,6 @@ internal sealed class NetEncryption
ArrayPool<byte>.Shared.Return(buffer);
if (!result)
throw new SodiumException("Decryption operation failed!");
return result;
}
}
@@ -220,7 +220,14 @@ namespace Robust.Shared.Network
// Expect login success here.
response = await AwaitData(connection, cancel);
encryption?.Decrypt(response);
// Attempt to decrypt the message, only logging if we fail to decrypt and we actually have encryption.
if ((!encryption?.TryDecrypt(response)) ?? false)
{
const string msg = "Failed to decrypt login success.";
connection.Disconnect(msg);
throw new Exception(msg);
}
}
var msgSuc = new MsgLoginSuccess();
+50 -13
View File
@@ -113,6 +113,11 @@ namespace Robust.Shared.Network
[Dependency] private readonly HttpClientHolder _http = default!;
[Dependency] private readonly IHWId _hwId = default!;
/// <summary>
/// Whether we bother to log problematic packets. Set by <see cref="CVars.NetLogging"/>.
/// </summary>
private bool _logPacketIssues = false;
/// <summary>
/// Holds lookup table for NetMessage.Id -> NetMessage.Type
/// </summary>
@@ -258,6 +263,7 @@ namespace Robust.Shared.Network
_config.OnValueChanged(CVars.NetLidgrenLogError, LidgrenLogErrorChanged);
_config.OnValueChanged(CVars.NetVerbose, NetVerboseChanged);
_config.OnValueChanged(CVars.NetLogging, NetLoggingChanged);
if (isServer)
{
_config.OnValueChanged(CVars.AuthMode, OnAuthModeChanged, invokeImmediately: true);
@@ -280,6 +286,11 @@ namespace Robust.Shared.Network
}
}
private void NetLoggingChanged(bool obj)
{
_logPacketIssues = obj;
}
private void LidgrenLogWarningChanged(bool newValue)
{
foreach (var netPeer in _netPeers)
@@ -870,19 +881,25 @@ namespace Robust.Shared.Network
var peer = msg.SenderConnection.Peer;
if (peer.Status == NetPeerStatus.ShutdownRequested)
{
_logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Received data message, but shutdown is requested.");
if (_logPacketIssues)
_logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Received data message, but shutdown is requested.");
return true;
}
if (peer.Status == NetPeerStatus.NotRunning)
{
_logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Received data message, peer is not running.");
if (_logPacketIssues)
_logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Received data message, peer is not running.");
return true;
}
if (!IsConnected)
{
_logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Received data message, but not connected.");
if (_logPacketIssues)
_logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Received data message, but not connected.");
return true;
}
@@ -897,19 +914,33 @@ namespace Robust.Shared.Network
if (msg.LengthBytes < 1)
{
_logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Received empty packet.");
if (_logPacketIssues)
_logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Received empty packet.");
msg.SenderConnection.Disconnect("Received empty/weird packet", false);
return true;
}
if (!_channels.TryGetValue(msg.SenderConnection, out var channel))
{
_logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Got unexpected data packet before handshake completion.");
if (_logPacketIssues)
_logger.Debug($"{msg.SenderConnection.RemoteEndPoint}: Got unexpected data packet before handshake completion.");
msg.SenderConnection.Disconnect("Unexpected packet before handshake completion");
msg.SenderConnection.Disconnect("Unexpected packet before handshake completion", false);
return true;
}
channel.Encryption?.Decrypt(msg);
// Attempt to decrypt the message, only logging if we fail to decrypt and we actually have encryption.
if ((!channel.Encryption?.TryDecrypt(msg)) ?? false)
{
if (_logPacketIssues)
_logger.Debug($"{msg.SenderConnection.RemoteEndPoint}: Got a packet that fails to decrypt.");
msg.SenderConnection.Disconnect("Failed to decrypt packet.", false);
return true;
}
var id = msg.ReadByte();
@@ -917,9 +948,10 @@ namespace Robust.Shared.Network
if (entry == null)
{
_logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Got net message with invalid ID {id}.");
if (_logPacketIssues)
_logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Got net message with invalid ID {id}.");
channel.Disconnect("Got NetMessage with invalid ID");
channel.Disconnect("Got NetMessage with invalid ID", false);
return true;
}
@@ -927,9 +959,10 @@ namespace Robust.Shared.Network
if (!channel.IsHandshakeComplete && !entry.IsHandshake)
{
_logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Got non-handshake message {entry.Type.Name} before handshake completion.");
if (_logPacketIssues)
_logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Got non-handshake message {entry.Type.Name} before handshake completion.");
channel.Disconnect("Got unacceptable net message before handshake completion");
channel.Disconnect("Got unacceptable net message before handshake completion", false);
return true;
}
@@ -955,12 +988,16 @@ namespace Robust.Shared.Network
}
catch (InvalidCastException ice)
{
_logger.Error($"{msg.SenderConnection.RemoteEndPoint}: Wrong deserialization of {type.Name} packet:\n{ice}");
if (_logPacketIssues)
_logger.Error($"{msg.SenderConnection.RemoteEndPoint}: Wrong deserialization of {type.Name} packet:\n{ice}");
channel.Disconnect("Failed to deserialize packet.", false);
return true;
}
catch (Exception e) // yes, we want to catch ALL exeptions for security
{
_logger.Error($"{msg.SenderConnection.RemoteEndPoint}: Failed to deserialize {type.Name} packet:\n{e}");
if (_logPacketIssues)
_logger.Error($"{msg.SenderConnection.RemoteEndPoint}: Failed to deserialize {type.Name} packet:\n{e}");
channel.Disconnect("Failed to deserialize packet.", false);
return true;
}