mirror of
https://github.com/space-wizards/RobustToolbox.git
synced 2026-02-15 03:30:53 +01:00
Mapped string serializer cleanup and fixes.
This commit is contained in:
@@ -121,7 +121,7 @@ namespace Robust.Server
|
||||
mgr.RootSawmill.AddHandler(handler);
|
||||
mgr.GetSawmill("res.typecheck").Level = LogLevel.Info;
|
||||
mgr.GetSawmill("go.sys").Level = LogLevel.Info;
|
||||
mgr.GetSawmill("szr").Level = LogLevel.Info;
|
||||
// mgr.GetSawmill("szr").Level = LogLevel.Info;
|
||||
|
||||
#if DEBUG_ONLY_FCE_INFO
|
||||
#if DEBUG_ONLY_FCE_LOG
|
||||
|
||||
@@ -22,8 +22,6 @@ namespace Robust.Shared.Network.Messages
|
||||
{
|
||||
}
|
||||
|
||||
public int PackageSize { get; set; }
|
||||
|
||||
/// <value>
|
||||
/// The raw bytes of the string mapping held by the server.
|
||||
/// </value>
|
||||
@@ -31,8 +29,8 @@ namespace Robust.Shared.Network.Messages
|
||||
|
||||
public override void ReadFromBuffer(NetIncomingMessage buffer)
|
||||
{
|
||||
PackageSize = buffer.ReadVariableInt32();
|
||||
buffer.ReadBytes(Package = new byte[PackageSize]);
|
||||
var size = buffer.ReadVariableInt32();
|
||||
buffer.ReadBytes(Package = new byte[size]);
|
||||
}
|
||||
|
||||
public override void WriteToBuffer(NetOutgoingMessage buffer)
|
||||
|
||||
@@ -8,6 +8,7 @@ using System.Reflection;
|
||||
using System.Reflection.Emit;
|
||||
using System.Runtime.Serialization;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Lidgren.Network;
|
||||
using Prometheus;
|
||||
using Robust.Shared.Configuration;
|
||||
@@ -695,7 +696,15 @@ namespace Robust.Shared.Network
|
||||
|
||||
_strings.SendFullTable(channel);
|
||||
|
||||
await _serializer.Handshake(channel);
|
||||
try
|
||||
{
|
||||
await _serializer.Handshake(channel);
|
||||
}
|
||||
catch (TaskCanceledException)
|
||||
{
|
||||
// Client disconnected during handshake.
|
||||
return;
|
||||
}
|
||||
|
||||
Logger.InfoS("net", "{ConnectionEndpoint}: Connected", channel.RemoteEndPoint);
|
||||
|
||||
|
||||
@@ -92,10 +92,11 @@ namespace Robust.Shared.Serialization
|
||||
using var zs = new DeflateStream(stream, CompressionMode.Decompress, true);
|
||||
using var hasherStream = new HasherStream(zs, hasher, true);
|
||||
|
||||
var count = ReadCompressedUnsignedInt(hasherStream, out _);
|
||||
Primitives.ReadPrimitive(hasherStream, out uint count);
|
||||
for (var i = 0; i < count; ++i)
|
||||
{
|
||||
var l = (int) ReadCompressedUnsignedInt(hasherStream, out _);
|
||||
Primitives.ReadPrimitive(hasherStream, out uint lu);
|
||||
var l = (int) lu;
|
||||
var y = hasherStream.Read(buf, 0, l);
|
||||
if (y != l)
|
||||
{
|
||||
@@ -126,7 +127,7 @@ namespace Robust.Shared.Serialization
|
||||
using (var zs = new DeflateStream(stream, CompressionLevel.Optimal, true))
|
||||
{
|
||||
using var hasherStream = new HasherStream(zs, hasher, true);
|
||||
WriteCompressedUnsignedInt(hasherStream, (uint) strings.Length);
|
||||
Primitives.WritePrimitive(hasherStream, (uint) strings.Length);
|
||||
|
||||
foreach (var str in strings)
|
||||
{
|
||||
@@ -139,7 +140,7 @@ namespace Robust.Shared.Serialization
|
||||
throw new NotImplementedException("Overly long string in strings package.");
|
||||
}
|
||||
|
||||
WriteCompressedUnsignedInt(hasherStream, (uint) l);
|
||||
Primitives.WritePrimitive(hasherStream, (uint) l);
|
||||
hasherStream.Write(buf[..l]);
|
||||
}
|
||||
}
|
||||
@@ -225,43 +226,21 @@ namespace Robust.Shared.Serialization
|
||||
|
||||
if (str.Contains('/'))
|
||||
{
|
||||
var parts = str.Split('/', StringSplitOptions.RemoveEmptyEntries);
|
||||
for (var i = 0; i < parts.Length; ++i)
|
||||
foreach (var substr in str.Split("/", StringSplitOptions.RemoveEmptyEntries))
|
||||
{
|
||||
for (var l = 1; l <= parts.Length - i; ++l)
|
||||
{
|
||||
var subStr = string.Join('/', parts.Skip(i).Take(l));
|
||||
if (!TryAddString(subStr))
|
||||
continue;
|
||||
|
||||
if (!subStr.Contains('.'))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
var subParts = subStr.Split('.', StringSplitOptions.RemoveEmptyEntries);
|
||||
for (var si = 0; si < subParts.Length; ++si)
|
||||
{
|
||||
for (var sl = 1; sl <= subParts.Length - si; ++sl)
|
||||
{
|
||||
var subSubStr = string.Join('.', subParts.Skip(si).Take(sl));
|
||||
// ReSharper disable once InvertIf
|
||||
TryAddString(subSubStr);
|
||||
}
|
||||
}
|
||||
}
|
||||
AddString(substr);
|
||||
}
|
||||
}
|
||||
else if (str.Contains("_"))
|
||||
{
|
||||
foreach (var substr in str.Split("_"))
|
||||
foreach (var substr in str.Split("_", StringSplitOptions.RemoveEmptyEntries))
|
||||
{
|
||||
AddString(substr);
|
||||
}
|
||||
}
|
||||
else if (str.Contains(" "))
|
||||
{
|
||||
foreach (var substr in str.Split(" "))
|
||||
foreach (var substr in str.Split(" ", StringSplitOptions.RemoveEmptyEntries))
|
||||
{
|
||||
if (substr == str) continue;
|
||||
|
||||
@@ -283,7 +262,7 @@ namespace Robust.Shared.Serialization
|
||||
for (var sl = 1; sl <= parts.Length - si; ++sl)
|
||||
{
|
||||
var subSubStr = String.Concat(parts.Skip(si).Take(sl));
|
||||
TryAddString(subSubStr);
|
||||
AddString(subSubStr);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -460,7 +439,7 @@ namespace Robust.Shared.Serialization
|
||||
|
||||
if (value == null)
|
||||
{
|
||||
WriteCompressedUnsignedInt(stream, MappedNull);
|
||||
Primitives.WritePrimitive(stream, MappedNull);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -473,14 +452,13 @@ namespace Robust.Shared.Serialization
|
||||
"A string mapping outside of the mapped string table was encountered.");
|
||||
}
|
||||
#endif
|
||||
WriteCompressedUnsignedInt(stream, (uint) mapping + FirstMappedIndexStart);
|
||||
Primitives.WritePrimitive(stream, (uint) mapping + FirstMappedIndexStart);
|
||||
//Logger.DebugS("szr", $"Encoded mapped string: {value}");
|
||||
return;
|
||||
}
|
||||
|
||||
// indicate not mapped
|
||||
WriteCompressedUnsignedInt(stream, UnmappedString);
|
||||
|
||||
Primitives.WritePrimitive(stream, UnmappedString);
|
||||
Primitives.WritePrimitive(stream, value);
|
||||
}
|
||||
|
||||
@@ -489,7 +467,7 @@ namespace Robust.Shared.Serialization
|
||||
{
|
||||
DebugTools.Assert(Locked);
|
||||
|
||||
var mapIndex = ReadCompressedUnsignedInt(stream, out _);
|
||||
Primitives.ReadPrimitive(stream, out uint mapIndex);
|
||||
if (mapIndex == MappedNull)
|
||||
{
|
||||
value = null;
|
||||
|
||||
@@ -4,11 +4,9 @@ using System.Diagnostics;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Reflection;
|
||||
using System.Runtime.InteropServices;
|
||||
using System.Security.Cryptography;
|
||||
using System.Text.RegularExpressions;
|
||||
using System.Threading.Tasks;
|
||||
using JetBrains.Annotations;
|
||||
using NetSerializer;
|
||||
using Newtonsoft.Json.Linq;
|
||||
using Robust.Shared.ContentPack;
|
||||
@@ -16,6 +14,7 @@ using Robust.Shared.Interfaces.Log;
|
||||
using Robust.Shared.Interfaces.Network;
|
||||
using Robust.Shared.IoC;
|
||||
using Robust.Shared.Log;
|
||||
using Robust.Shared.Network;
|
||||
using Robust.Shared.Network.Messages;
|
||||
using Robust.Shared.Utility;
|
||||
using YamlDotNet.RepresentationModel;
|
||||
@@ -41,7 +40,7 @@ namespace Robust.Shared.Serialization
|
||||
/// send the constant value instead - and at the other end, the
|
||||
/// serializer can use the same mapping to recover the original string.
|
||||
/// </remarks>
|
||||
internal partial class RobustMappedStringSerializer : IStaticTypeSerializer, IRobustMappedStringSerializer
|
||||
internal sealed partial class RobustMappedStringSerializer : IStaticTypeSerializer, IRobustMappedStringSerializer
|
||||
{
|
||||
private delegate void WriteStringDelegate(Stream stream, string? value);
|
||||
|
||||
@@ -79,13 +78,13 @@ namespace Robust.Shared.Serialization
|
||||
/// The special value corresponding to a <c>null</c> string in the
|
||||
/// encoding.
|
||||
/// </summary>
|
||||
private const int MappedNull = 0;
|
||||
private const uint MappedNull = 0;
|
||||
|
||||
/// <summary>
|
||||
/// The special value corresponding to a string which was not mapped.
|
||||
/// This is followed by the bytes of the unmapped string.
|
||||
/// </summary>
|
||||
private const int UnmappedString = 1;
|
||||
private const uint UnmappedString = 1;
|
||||
|
||||
/// <summary>
|
||||
/// The first non-special value, used for encoding mapped strings.
|
||||
@@ -97,7 +96,7 @@ namespace Robust.Shared.Serialization
|
||||
/// <c>>= FirstMappedIndexStart</c> represents the string with
|
||||
/// mapping of that value <c> - FirstMappedIndexStart</c>.
|
||||
/// </remarks>
|
||||
private const int FirstMappedIndexStart = 2;
|
||||
private const uint FirstMappedIndexStart = 2;
|
||||
|
||||
[Dependency] private readonly INetManager _net = default!;
|
||||
|
||||
@@ -107,7 +106,8 @@ namespace Robust.Shared.Serialization
|
||||
|
||||
private MappedStringDict _dict = default!;
|
||||
|
||||
private readonly HashSet<INetChannel> _incompleteHandshakes = new HashSet<INetChannel>();
|
||||
private readonly Dictionary<INetChannel, InProgressHandshake> _incompleteHandshakes
|
||||
= new Dictionary<INetChannel, InProgressHandshake>();
|
||||
|
||||
private byte[]? _mappedStringsPackage;
|
||||
private byte[]? _serverHash;
|
||||
@@ -147,23 +147,20 @@ namespace Robust.Shared.Serialization
|
||||
/// </remarks>
|
||||
/// <seealso cref="MsgMapStrClientHandshake"/>
|
||||
/// <seealso cref="MsgMapStrStrings"/>
|
||||
public async Task Handshake(INetChannel channel)
|
||||
public Task Handshake(INetChannel channel)
|
||||
{
|
||||
DebugTools.Assert(_net.IsServer);
|
||||
DebugTools.Assert(_dict.Locked);
|
||||
|
||||
_incompleteHandshakes.Add(channel);
|
||||
var tcs = new TaskCompletionSource<object?>();
|
||||
|
||||
_incompleteHandshakes.Add(channel, new InProgressHandshake(tcs));
|
||||
|
||||
var message = _net.CreateNetMessage<MsgMapStrServerHandshake>();
|
||||
message.Hash = _stringMapHash;
|
||||
_net.ServerSendMessage(message, channel);
|
||||
|
||||
while (_incompleteHandshakes.Contains(channel))
|
||||
{
|
||||
await Task.Delay(1);
|
||||
}
|
||||
|
||||
LogSzr.Debug($"Completed handshake with {channel.RemoteEndPoint.Address}.");
|
||||
return tcs.Task;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
@@ -219,6 +216,22 @@ namespace Robust.Shared.Serialization
|
||||
_net.RegisterNetMessage<MsgMapStrServerHandshake>(nameof(MsgMapStrServerHandshake), HandleServerHandshake);
|
||||
_net.RegisterNetMessage<MsgMapStrClientHandshake>(nameof(MsgMapStrClientHandshake), HandleClientHandshake);
|
||||
_net.RegisterNetMessage<MsgMapStrStrings>(nameof(MsgMapStrStrings), HandleStringsMessage);
|
||||
|
||||
_net.Disconnect += NetOnDisconnect;
|
||||
}
|
||||
|
||||
private void NetOnDisconnect(object? sender, NetDisconnectedArgs e)
|
||||
{
|
||||
// Cancel handshake in-progress if client disconnects mid-handshake.
|
||||
var channel = e.Channel;
|
||||
if (_incompleteHandshakes.TryGetValue(channel, out var handshake))
|
||||
{
|
||||
var tcs = handshake.Tcs;
|
||||
LogSzr.Debug($"Cancelling handshake for disconnected client {channel.SessionId}");
|
||||
tcs.SetCanceled();
|
||||
}
|
||||
|
||||
_incompleteHandshakes.Remove(channel);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
@@ -320,23 +333,38 @@ namespace Robust.Shared.Serialization
|
||||
DebugTools.Assert(_net.IsServer);
|
||||
DebugTools.Assert(_dict.Locked);
|
||||
|
||||
LogSzr.Debug($"Received handshake from {msgMapStr.MsgChannel.RemoteEndPoint.Address}.");
|
||||
var channel = msgMapStr.MsgChannel;
|
||||
LogSzr.Debug($"Received handshake from {channel.SessionId}.");
|
||||
|
||||
if (!msgMapStr.NeedsStrings)
|
||||
if (!_incompleteHandshakes.TryGetValue(channel, out var handshake))
|
||||
{
|
||||
LogSzr.Debug($"Completing handshake with {msgMapStr.MsgChannel.RemoteEndPoint.Address}.");
|
||||
_incompleteHandshakes.Remove(msgMapStr.MsgChannel);
|
||||
channel.Disconnect("MsgMapStrClientHandshake without in-progress handshake.");
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO: count and limit number of requests to send strings during handshake
|
||||
if (!msgMapStr.NeedsStrings)
|
||||
{
|
||||
LogSzr.Debug($"Completing handshake with {channel.SessionId}.");
|
||||
|
||||
handshake.Tcs.SetResult(null);
|
||||
_incompleteHandshakes.Remove(channel);
|
||||
return;
|
||||
}
|
||||
|
||||
if (handshake.HasRequestedStrings)
|
||||
{
|
||||
channel.Disconnect("Cannot request strings twice");
|
||||
return;
|
||||
}
|
||||
|
||||
handshake.HasRequestedStrings = true;
|
||||
|
||||
var strings = _net.CreateNetMessage<MsgMapStrStrings>();
|
||||
strings.Package = _mappedStringsPackage;
|
||||
LogSzr.Debug(
|
||||
$"Sending {_mappedStringsPackage!.Length} bytes sized mapped strings package to {msgMapStr.MsgChannel.SessionId}.");
|
||||
$"Sending {_mappedStringsPackage!.Length} bytes sized mapped strings package to {channel.SessionId}.");
|
||||
|
||||
_net.ServerSendMessage(strings, msgMapStr.MsgChannel);
|
||||
_net.ServerSendMessage(strings, channel);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
@@ -595,77 +623,6 @@ namespace Robust.Shared.Serialization
|
||||
mss._dict.ReadMappedString(stream, out value);
|
||||
}
|
||||
|
||||
// TODO: move the below methods to some stream helpers class
|
||||
|
||||
#if ROBUST_SERIALIZER_DISABLE_COMPRESSED_UINTS
|
||||
private static int WriteCompressedUnsignedInt(Stream stream, uint value)
|
||||
{
|
||||
WriteUnsignedInt(stream, value);
|
||||
return 4;
|
||||
}
|
||||
|
||||
private static uint ReadCompressedUnsignedInt(Stream stream, out int byteCount)
|
||||
{
|
||||
byteCount = 4;
|
||||
return ReadUnsignedInt(stream);
|
||||
}
|
||||
#else
|
||||
private static int WriteCompressedUnsignedInt(Stream stream, uint value)
|
||||
{
|
||||
var length = 1;
|
||||
while (value >= 0x80)
|
||||
{
|
||||
stream.WriteByte((byte) (0x80 | value));
|
||||
value >>= 7;
|
||||
++length;
|
||||
}
|
||||
|
||||
stream.WriteByte((byte) value);
|
||||
return length;
|
||||
}
|
||||
|
||||
private static uint ReadCompressedUnsignedInt(Stream stream, out int byteCount)
|
||||
{
|
||||
byteCount = 0;
|
||||
var value = 0u;
|
||||
var shift = 0;
|
||||
while (stream.CanRead)
|
||||
{
|
||||
var current = stream.ReadByte();
|
||||
++byteCount;
|
||||
if (current == -1)
|
||||
{
|
||||
throw new EndOfStreamException();
|
||||
}
|
||||
|
||||
value |= (0x7Fu & (byte) current) << shift;
|
||||
shift += 7;
|
||||
if ((0x80 & current) == 0)
|
||||
{
|
||||
return value;
|
||||
}
|
||||
}
|
||||
|
||||
throw new EndOfStreamException();
|
||||
}
|
||||
#endif
|
||||
|
||||
[UsedImplicitly]
|
||||
private static unsafe void WriteUnsignedInt(Stream stream, uint value)
|
||||
{
|
||||
var bytes = MemoryMarshal.AsBytes(new ReadOnlySpan<uint>(&value, 1));
|
||||
stream.Write(bytes);
|
||||
}
|
||||
|
||||
[UsedImplicitly]
|
||||
private static unsafe uint ReadUnsignedInt(Stream stream)
|
||||
{
|
||||
uint value;
|
||||
var bytes = MemoryMarshal.AsBytes(new Span<uint>(&value, 1));
|
||||
stream.Read(bytes);
|
||||
return value;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// See <see cref="OnClientCompleteHandshake"/>.
|
||||
/// </summary>
|
||||
@@ -681,7 +638,7 @@ namespace Robust.Shared.Serialization
|
||||
|
||||
(_stringMapHash, _mappedStringsPackage) = _dict.GeneratePackage();
|
||||
|
||||
LogSzr.Debug($"Wrote string package in {sw.ElapsedMilliseconds}ms");
|
||||
LogSzr.Debug($"Wrote string package in {sw.ElapsedMilliseconds}ms size {ByteHelpers.FormatBytes(_mappedStringsPackage.Length)}");
|
||||
LogSzr.Debug($"String hash is {ConvertToBase64Url(_stringMapHash)}");
|
||||
}
|
||||
|
||||
@@ -698,5 +655,16 @@ namespace Robust.Shared.Serialization
|
||||
|
||||
NetworkInitialize();
|
||||
}
|
||||
|
||||
private sealed class InProgressHandshake
|
||||
{
|
||||
public readonly TaskCompletionSource<object?> Tcs;
|
||||
public bool HasRequestedStrings;
|
||||
|
||||
public InProgressHandshake(TaskCompletionSource<object?> tcs)
|
||||
{
|
||||
Tcs = tcs;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
using System;
|
||||
using System.Threading.Tasks;
|
||||
using Lidgren.Network;
|
||||
using Robust.Shared.Interfaces.Network;
|
||||
|
||||
namespace Robust.Shared.Serialization
|
||||
|
||||
37
Robust.Shared/Utility/ByteHelpers.cs
Normal file
37
Robust.Shared/Utility/ByteHelpers.cs
Normal file
@@ -0,0 +1,37 @@
|
||||
using System;
|
||||
|
||||
namespace Robust.Shared.Utility
|
||||
{
|
||||
public static class ByteHelpers
|
||||
{
|
||||
public static string FormatKibibytes(long bytes)
|
||||
{
|
||||
return $"{bytes / 1024} KiB";
|
||||
}
|
||||
|
||||
public static string FormatBytes(long bytes)
|
||||
{
|
||||
double d = bytes;
|
||||
var i = 0;
|
||||
for (; i < ByteSuffixes.Length && d >= 1024; i++)
|
||||
{
|
||||
d /= 1024;
|
||||
}
|
||||
|
||||
return $"{Math.Round(d, 2)} {ByteSuffixes[i]}";
|
||||
}
|
||||
|
||||
private static readonly string[] ByteSuffixes =
|
||||
{
|
||||
"B",
|
||||
"KiB",
|
||||
"MiB",
|
||||
"GiB",
|
||||
"TiB",
|
||||
"PiB",
|
||||
"EiB",
|
||||
"ZiB",
|
||||
"YiB"
|
||||
};
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user