mirror of
https://github.com/space-wizards/RobustToolbox.git
synced 2026-06-09 10:06:34 +02:00
DOS patch
This commit is contained in:
@@ -0,0 +1,178 @@
|
||||
using Lidgren.Network;
|
||||
using NUnit.Framework;
|
||||
using Robust.Shared.Network;
|
||||
|
||||
namespace Robust.Shared.Tests.Networking;
|
||||
|
||||
public sealed class NetEncryptionDoSTest
|
||||
{
|
||||
private const ulong Magic = 0x13377777_77777777;
|
||||
|
||||
[Test]
|
||||
[Description("A control test that ensures connecting in a test works.")]
|
||||
public void ConnectionWorks()
|
||||
{
|
||||
var (client, server) = MakeConnectionPair();
|
||||
|
||||
var message = client.CreateMessage();
|
||||
|
||||
message.WriteVariableUInt64(Magic);
|
||||
|
||||
client.SendMessage(message, NetDeliveryMethod.ReliableOrdered);
|
||||
|
||||
var packet = Receive(server);
|
||||
|
||||
Assert.That(packet.ReadVariableUInt64(), Is.EqualTo(Magic));
|
||||
server.Shutdown(null);
|
||||
}
|
||||
|
||||
[Test]
|
||||
[Description("A control test that just ensures encryption works as other tests expect.")]
|
||||
public void EncryptionWorks()
|
||||
{
|
||||
var (clientEnc, serverEnc) = MakeEncryptionPair();
|
||||
var (client, server) = MakeConnectionPair();
|
||||
|
||||
var message = client.CreateMessage();
|
||||
|
||||
message.WriteVariableUInt64(Magic);
|
||||
|
||||
clientEnc.Encrypt(message);
|
||||
|
||||
client.SendMessage(message, NetDeliveryMethod.ReliableOrdered);
|
||||
|
||||
var packet = Receive(server);
|
||||
|
||||
Assert.That(serverEnc.TryDecrypt(packet), Is.True);
|
||||
server.Shutdown(null);
|
||||
}
|
||||
|
||||
[Test]
|
||||
[Description("Attempt to decrypt a packet that is using the wrong encryption keys, ensuring it doesn't throw.")]
|
||||
public void WrongKeyFailureDoesNotThrow()
|
||||
{
|
||||
var (clientEnc, serverEnc) = MakeEncryptionPair(disjointKey: true);
|
||||
var (client, server) = MakeConnectionPair();
|
||||
|
||||
var message = client.CreateMessage();
|
||||
|
||||
message.WriteVariableUInt64(Magic);
|
||||
|
||||
clientEnc.Encrypt(message);
|
||||
|
||||
client.SendMessage(message, NetDeliveryMethod.ReliableOrdered);
|
||||
|
||||
var packet = Receive(server);
|
||||
|
||||
Assert.That(serverEnc.TryDecrypt(packet), Is.False);
|
||||
server.Shutdown(null);
|
||||
}
|
||||
|
||||
private static int[] _badMessages =
|
||||
[
|
||||
5,
|
||||
1,
|
||||
4,
|
||||
16,
|
||||
1024,
|
||||
];
|
||||
|
||||
[Test]
|
||||
[Description("Attempt to decrypt a packet that is bogus, ensuring it doesn't throw.")]
|
||||
[TestCaseSource(nameof(_badMessages))]
|
||||
public void BadMessageDoesNotThrow(int badMessageLength)
|
||||
{
|
||||
var badMessage = new byte[badMessageLength];
|
||||
System.Random.Shared.NextBytes(badMessage);
|
||||
var (_, serverEnc) = MakeEncryptionPair(disjointKey: true);
|
||||
var (client, server) = MakeConnectionPair();
|
||||
|
||||
var message = client.CreateMessage();
|
||||
|
||||
message.Write(badMessage);
|
||||
|
||||
// Don't encrypt at all.
|
||||
|
||||
client.SendMessage(message, NetDeliveryMethod.ReliableOrdered);
|
||||
|
||||
var packet = Receive(server);
|
||||
|
||||
Assert.That(packet.LengthBytes, Is.EqualTo(badMessageLength));
|
||||
|
||||
Assert.That(serverEnc.TryDecrypt(packet), Is.False);
|
||||
|
||||
server.Shutdown(null);
|
||||
}
|
||||
|
||||
|
||||
// TODO: Generalize all this for other low level network tests.
|
||||
|
||||
private (NetClient client, NetServer server) MakeConnectionPair()
|
||||
{
|
||||
const string id = "test";
|
||||
var client = new NetClient(new NetPeerConfiguration(id));
|
||||
|
||||
var server = new NetServer( new NetPeerConfiguration(id));
|
||||
|
||||
client.Start();
|
||||
// Lidgren has no facilities for mocking this nicely.
|
||||
// So we just use an actual socket.
|
||||
server.Start();
|
||||
|
||||
client.Connect("localhost", server.Port);
|
||||
|
||||
var ready = false;
|
||||
|
||||
while (!ready)
|
||||
{
|
||||
switch (server.WaitMessage(1000))
|
||||
{
|
||||
case { MessageType: NetIncomingMessageType.StatusChanged } msg:
|
||||
{
|
||||
// hello there.
|
||||
var status = (NetConnectionStatus)msg.ReadByte();
|
||||
|
||||
if (status == NetConnectionStatus.Connected)
|
||||
ready = true;
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return (client, server);
|
||||
}
|
||||
|
||||
private NetIncomingMessage Receive(NetPeer peer)
|
||||
{
|
||||
NetIncomingMessage? found = null;
|
||||
|
||||
while (found == null)
|
||||
{
|
||||
switch (peer.WaitMessage(1000))
|
||||
{
|
||||
case { MessageType: NetIncomingMessageType.Data } msg:
|
||||
{
|
||||
found = msg;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return found;
|
||||
}
|
||||
|
||||
private (NetEncryption client, NetEncryption server) MakeEncryptionPair(bool disjointKey = false)
|
||||
{
|
||||
var serverKey = new byte[32];
|
||||
|
||||
System.Random.Shared.NextBytes(serverKey.AsSpan());
|
||||
|
||||
var clientKey = (byte[])serverKey.Clone();
|
||||
|
||||
if (disjointKey)
|
||||
System.Random.Shared.NextBytes(clientKey.AsSpan());
|
||||
|
||||
return (new NetEncryption(clientKey, false), new NetEncryption(serverKey, true));
|
||||
}
|
||||
}
|
||||
@@ -84,8 +84,18 @@ internal sealed class NetEncryption
|
||||
ArrayPool<byte>.Shared.Return(returnPool);
|
||||
}
|
||||
|
||||
public unsafe void Decrypt(NetIncomingMessage message)
|
||||
/// <summary>
|
||||
/// Attempts to decrypt an incoming network message, falliably.
|
||||
/// </summary>
|
||||
/// <param name="message">The message to decrypt in-place. This will be mutated with the decrypted results.</param>
|
||||
/// <returns>Whether the operation was successful. If this fails, you likely want to drop the connection.</returns>
|
||||
public unsafe bool TryDecrypt(NetIncomingMessage message)
|
||||
{
|
||||
// Minimum possible size a message can be is the nonce + 16 bytes of message.
|
||||
// So we immediately bail on anything smaller.
|
||||
if (message.LengthBytes < sizeof(ulong) + CryptoAeadXChaCha20Poly1305Ietf.AddBytes)
|
||||
return false;
|
||||
|
||||
var nonce = message.ReadUInt64();
|
||||
var cipherText = message.Data.AsSpan(sizeof(ulong), message.LengthBytes - sizeof(ulong));
|
||||
|
||||
@@ -114,7 +124,6 @@ internal sealed class NetEncryption
|
||||
|
||||
ArrayPool<byte>.Shared.Return(buffer);
|
||||
|
||||
if (!result)
|
||||
throw new SodiumException("Decryption operation failed!");
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -220,7 +220,14 @@ namespace Robust.Shared.Network
|
||||
|
||||
// Expect login success here.
|
||||
response = await AwaitData(connection, cancel);
|
||||
encryption?.Decrypt(response);
|
||||
|
||||
// Attempt to decrypt the message, only logging if we fail to decrypt and we actually have encryption.
|
||||
if ((!encryption?.TryDecrypt(response)) ?? false)
|
||||
{
|
||||
const string msg = "Failed to decrypt login success.";
|
||||
connection.Disconnect(msg);
|
||||
throw new Exception(msg);
|
||||
}
|
||||
}
|
||||
|
||||
var msgSuc = new MsgLoginSuccess();
|
||||
|
||||
@@ -113,6 +113,11 @@ namespace Robust.Shared.Network
|
||||
[Dependency] private readonly HttpClientHolder _http = default!;
|
||||
[Dependency] private readonly IHWId _hwId = default!;
|
||||
|
||||
/// <summary>
|
||||
/// Whether we bother to log problematic packets. Set by <see cref="CVars.NetLogging"/>.
|
||||
/// </summary>
|
||||
private bool _logPacketIssues = false;
|
||||
|
||||
/// <summary>
|
||||
/// Holds lookup table for NetMessage.Id -> NetMessage.Type
|
||||
/// </summary>
|
||||
@@ -258,6 +263,7 @@ namespace Robust.Shared.Network
|
||||
_config.OnValueChanged(CVars.NetLidgrenLogError, LidgrenLogErrorChanged);
|
||||
|
||||
_config.OnValueChanged(CVars.NetVerbose, NetVerboseChanged);
|
||||
_config.OnValueChanged(CVars.NetLogging, NetLoggingChanged);
|
||||
if (isServer)
|
||||
{
|
||||
_config.OnValueChanged(CVars.AuthMode, OnAuthModeChanged, invokeImmediately: true);
|
||||
@@ -280,6 +286,11 @@ namespace Robust.Shared.Network
|
||||
}
|
||||
}
|
||||
|
||||
private void NetLoggingChanged(bool obj)
|
||||
{
|
||||
_logPacketIssues = obj;
|
||||
}
|
||||
|
||||
private void LidgrenLogWarningChanged(bool newValue)
|
||||
{
|
||||
foreach (var netPeer in _netPeers)
|
||||
@@ -870,19 +881,25 @@ namespace Robust.Shared.Network
|
||||
var peer = msg.SenderConnection.Peer;
|
||||
if (peer.Status == NetPeerStatus.ShutdownRequested)
|
||||
{
|
||||
_logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Received data message, but shutdown is requested.");
|
||||
if (_logPacketIssues)
|
||||
_logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Received data message, but shutdown is requested.");
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
if (peer.Status == NetPeerStatus.NotRunning)
|
||||
{
|
||||
_logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Received data message, peer is not running.");
|
||||
if (_logPacketIssues)
|
||||
_logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Received data message, peer is not running.");
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!IsConnected)
|
||||
{
|
||||
_logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Received data message, but not connected.");
|
||||
if (_logPacketIssues)
|
||||
_logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Received data message, but not connected.");
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -897,19 +914,33 @@ namespace Robust.Shared.Network
|
||||
|
||||
if (msg.LengthBytes < 1)
|
||||
{
|
||||
_logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Received empty packet.");
|
||||
if (_logPacketIssues)
|
||||
_logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Received empty packet.");
|
||||
|
||||
msg.SenderConnection.Disconnect("Received empty/weird packet", false);
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!_channels.TryGetValue(msg.SenderConnection, out var channel))
|
||||
{
|
||||
_logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Got unexpected data packet before handshake completion.");
|
||||
if (_logPacketIssues)
|
||||
_logger.Debug($"{msg.SenderConnection.RemoteEndPoint}: Got unexpected data packet before handshake completion.");
|
||||
|
||||
msg.SenderConnection.Disconnect("Unexpected packet before handshake completion");
|
||||
|
||||
msg.SenderConnection.Disconnect("Unexpected packet before handshake completion", false);
|
||||
return true;
|
||||
}
|
||||
|
||||
channel.Encryption?.Decrypt(msg);
|
||||
// Attempt to decrypt the message, only logging if we fail to decrypt and we actually have encryption.
|
||||
if ((!channel.Encryption?.TryDecrypt(msg)) ?? false)
|
||||
{
|
||||
if (_logPacketIssues)
|
||||
_logger.Debug($"{msg.SenderConnection.RemoteEndPoint}: Got a packet that fails to decrypt.");
|
||||
|
||||
|
||||
msg.SenderConnection.Disconnect("Failed to decrypt packet.", false);
|
||||
return true;
|
||||
}
|
||||
|
||||
var id = msg.ReadByte();
|
||||
|
||||
@@ -917,9 +948,10 @@ namespace Robust.Shared.Network
|
||||
|
||||
if (entry == null)
|
||||
{
|
||||
_logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Got net message with invalid ID {id}.");
|
||||
if (_logPacketIssues)
|
||||
_logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Got net message with invalid ID {id}.");
|
||||
|
||||
channel.Disconnect("Got NetMessage with invalid ID");
|
||||
channel.Disconnect("Got NetMessage with invalid ID", false);
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -927,9 +959,10 @@ namespace Robust.Shared.Network
|
||||
|
||||
if (!channel.IsHandshakeComplete && !entry.IsHandshake)
|
||||
{
|
||||
_logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Got non-handshake message {entry.Type.Name} before handshake completion.");
|
||||
if (_logPacketIssues)
|
||||
_logger.Warning($"{msg.SenderConnection.RemoteEndPoint}: Got non-handshake message {entry.Type.Name} before handshake completion.");
|
||||
|
||||
channel.Disconnect("Got unacceptable net message before handshake completion");
|
||||
channel.Disconnect("Got unacceptable net message before handshake completion", false);
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -955,12 +988,16 @@ namespace Robust.Shared.Network
|
||||
}
|
||||
catch (InvalidCastException ice)
|
||||
{
|
||||
_logger.Error($"{msg.SenderConnection.RemoteEndPoint}: Wrong deserialization of {type.Name} packet:\n{ice}");
|
||||
if (_logPacketIssues)
|
||||
_logger.Error($"{msg.SenderConnection.RemoteEndPoint}: Wrong deserialization of {type.Name} packet:\n{ice}");
|
||||
channel.Disconnect("Failed to deserialize packet.", false);
|
||||
return true;
|
||||
}
|
||||
catch (Exception e) // yes, we want to catch ALL exeptions for security
|
||||
{
|
||||
_logger.Error($"{msg.SenderConnection.RemoteEndPoint}: Failed to deserialize {type.Name} packet:\n{e}");
|
||||
if (_logPacketIssues)
|
||||
_logger.Error($"{msg.SenderConnection.RemoteEndPoint}: Failed to deserialize {type.Name} packet:\n{e}");
|
||||
channel.Disconnect("Failed to deserialize packet.", false);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user