Mapped string serializer cleanup and fixes.

This commit is contained in:
Pieter-Jan Briers
2020-08-31 22:58:27 +02:00
parent 59a003cda0
commit 36f29d54ed
7 changed files with 126 additions and 137 deletions

View File

@@ -121,7 +121,7 @@ namespace Robust.Server
mgr.RootSawmill.AddHandler(handler);
mgr.GetSawmill("res.typecheck").Level = LogLevel.Info;
mgr.GetSawmill("go.sys").Level = LogLevel.Info;
mgr.GetSawmill("szr").Level = LogLevel.Info;
// mgr.GetSawmill("szr").Level = LogLevel.Info;
#if DEBUG_ONLY_FCE_INFO
#if DEBUG_ONLY_FCE_LOG

View File

@@ -22,8 +22,6 @@ namespace Robust.Shared.Network.Messages
{
}
public int PackageSize { get; set; }
/// <value>
/// The raw bytes of the string mapping held by the server.
/// </value>
@@ -31,8 +29,8 @@ namespace Robust.Shared.Network.Messages
public override void ReadFromBuffer(NetIncomingMessage buffer)
{
PackageSize = buffer.ReadVariableInt32();
buffer.ReadBytes(Package = new byte[PackageSize]);
var size = buffer.ReadVariableInt32();
buffer.ReadBytes(Package = new byte[size]);
}
public override void WriteToBuffer(NetOutgoingMessage buffer)

View File

@@ -8,6 +8,7 @@ using System.Reflection;
using System.Reflection.Emit;
using System.Runtime.Serialization;
using System.Threading;
using System.Threading.Tasks;
using Lidgren.Network;
using Prometheus;
using Robust.Shared.Configuration;
@@ -695,7 +696,15 @@ namespace Robust.Shared.Network
_strings.SendFullTable(channel);
await _serializer.Handshake(channel);
try
{
await _serializer.Handshake(channel);
}
catch (TaskCanceledException)
{
// Client disconnected during handshake.
return;
}
Logger.InfoS("net", "{ConnectionEndpoint}: Connected", channel.RemoteEndPoint);

View File

@@ -92,10 +92,11 @@ namespace Robust.Shared.Serialization
using var zs = new DeflateStream(stream, CompressionMode.Decompress, true);
using var hasherStream = new HasherStream(zs, hasher, true);
var count = ReadCompressedUnsignedInt(hasherStream, out _);
Primitives.ReadPrimitive(hasherStream, out uint count);
for (var i = 0; i < count; ++i)
{
var l = (int) ReadCompressedUnsignedInt(hasherStream, out _);
Primitives.ReadPrimitive(hasherStream, out uint lu);
var l = (int) lu;
var y = hasherStream.Read(buf, 0, l);
if (y != l)
{
@@ -126,7 +127,7 @@ namespace Robust.Shared.Serialization
using (var zs = new DeflateStream(stream, CompressionLevel.Optimal, true))
{
using var hasherStream = new HasherStream(zs, hasher, true);
WriteCompressedUnsignedInt(hasherStream, (uint) strings.Length);
Primitives.WritePrimitive(hasherStream, (uint) strings.Length);
foreach (var str in strings)
{
@@ -139,7 +140,7 @@ namespace Robust.Shared.Serialization
throw new NotImplementedException("Overly long string in strings package.");
}
WriteCompressedUnsignedInt(hasherStream, (uint) l);
Primitives.WritePrimitive(hasherStream, (uint) l);
hasherStream.Write(buf[..l]);
}
}
@@ -225,43 +226,21 @@ namespace Robust.Shared.Serialization
if (str.Contains('/'))
{
var parts = str.Split('/', StringSplitOptions.RemoveEmptyEntries);
for (var i = 0; i < parts.Length; ++i)
foreach (var substr in str.Split("/", StringSplitOptions.RemoveEmptyEntries))
{
for (var l = 1; l <= parts.Length - i; ++l)
{
var subStr = string.Join('/', parts.Skip(i).Take(l));
if (!TryAddString(subStr))
continue;
if (!subStr.Contains('.'))
{
continue;
}
var subParts = subStr.Split('.', StringSplitOptions.RemoveEmptyEntries);
for (var si = 0; si < subParts.Length; ++si)
{
for (var sl = 1; sl <= subParts.Length - si; ++sl)
{
var subSubStr = string.Join('.', subParts.Skip(si).Take(sl));
// ReSharper disable once InvertIf
TryAddString(subSubStr);
}
}
}
AddString(substr);
}
}
else if (str.Contains("_"))
{
foreach (var substr in str.Split("_"))
foreach (var substr in str.Split("_", StringSplitOptions.RemoveEmptyEntries))
{
AddString(substr);
}
}
else if (str.Contains(" "))
{
foreach (var substr in str.Split(" "))
foreach (var substr in str.Split(" ", StringSplitOptions.RemoveEmptyEntries))
{
if (substr == str) continue;
@@ -283,7 +262,7 @@ namespace Robust.Shared.Serialization
for (var sl = 1; sl <= parts.Length - si; ++sl)
{
var subSubStr = String.Concat(parts.Skip(si).Take(sl));
TryAddString(subSubStr);
AddString(subSubStr);
}
}
}
@@ -460,7 +439,7 @@ namespace Robust.Shared.Serialization
if (value == null)
{
WriteCompressedUnsignedInt(stream, MappedNull);
Primitives.WritePrimitive(stream, MappedNull);
return;
}
@@ -473,14 +452,13 @@ namespace Robust.Shared.Serialization
"A string mapping outside of the mapped string table was encountered.");
}
#endif
WriteCompressedUnsignedInt(stream, (uint) mapping + FirstMappedIndexStart);
Primitives.WritePrimitive(stream, (uint) mapping + FirstMappedIndexStart);
//Logger.DebugS("szr", $"Encoded mapped string: {value}");
return;
}
// indicate not mapped
WriteCompressedUnsignedInt(stream, UnmappedString);
Primitives.WritePrimitive(stream, UnmappedString);
Primitives.WritePrimitive(stream, value);
}
@@ -489,7 +467,7 @@ namespace Robust.Shared.Serialization
{
DebugTools.Assert(Locked);
var mapIndex = ReadCompressedUnsignedInt(stream, out _);
Primitives.ReadPrimitive(stream, out uint mapIndex);
if (mapIndex == MappedNull)
{
value = null;

View File

@@ -4,11 +4,9 @@ using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Runtime.InteropServices;
using System.Security.Cryptography;
using System.Text.RegularExpressions;
using System.Threading.Tasks;
using JetBrains.Annotations;
using NetSerializer;
using Newtonsoft.Json.Linq;
using Robust.Shared.ContentPack;
@@ -16,6 +14,7 @@ using Robust.Shared.Interfaces.Log;
using Robust.Shared.Interfaces.Network;
using Robust.Shared.IoC;
using Robust.Shared.Log;
using Robust.Shared.Network;
using Robust.Shared.Network.Messages;
using Robust.Shared.Utility;
using YamlDotNet.RepresentationModel;
@@ -41,7 +40,7 @@ namespace Robust.Shared.Serialization
/// send the constant value instead - and at the other end, the
/// serializer can use the same mapping to recover the original string.
/// </remarks>
internal partial class RobustMappedStringSerializer : IStaticTypeSerializer, IRobustMappedStringSerializer
internal sealed partial class RobustMappedStringSerializer : IStaticTypeSerializer, IRobustMappedStringSerializer
{
private delegate void WriteStringDelegate(Stream stream, string? value);
@@ -79,13 +78,13 @@ namespace Robust.Shared.Serialization
/// The special value corresponding to a <c>null</c> string in the
/// encoding.
/// </summary>
private const int MappedNull = 0;
private const uint MappedNull = 0;
/// <summary>
/// The special value corresponding to a string which was not mapped.
/// This is followed by the bytes of the unmapped string.
/// </summary>
private const int UnmappedString = 1;
private const uint UnmappedString = 1;
/// <summary>
/// The first non-special value, used for encoding mapped strings.
@@ -97,7 +96,7 @@ namespace Robust.Shared.Serialization
/// <c>>= FirstMappedIndexStart</c> represents the string with
/// mapping of that value <c> - FirstMappedIndexStart</c>.
/// </remarks>
private const int FirstMappedIndexStart = 2;
private const uint FirstMappedIndexStart = 2;
[Dependency] private readonly INetManager _net = default!;
@@ -107,7 +106,8 @@ namespace Robust.Shared.Serialization
private MappedStringDict _dict = default!;
private readonly HashSet<INetChannel> _incompleteHandshakes = new HashSet<INetChannel>();
private readonly Dictionary<INetChannel, InProgressHandshake> _incompleteHandshakes
= new Dictionary<INetChannel, InProgressHandshake>();
private byte[]? _mappedStringsPackage;
private byte[]? _serverHash;
@@ -147,23 +147,20 @@ namespace Robust.Shared.Serialization
/// </remarks>
/// <seealso cref="MsgMapStrClientHandshake"/>
/// <seealso cref="MsgMapStrStrings"/>
public async Task Handshake(INetChannel channel)
public Task Handshake(INetChannel channel)
{
DebugTools.Assert(_net.IsServer);
DebugTools.Assert(_dict.Locked);
_incompleteHandshakes.Add(channel);
var tcs = new TaskCompletionSource<object?>();
_incompleteHandshakes.Add(channel, new InProgressHandshake(tcs));
var message = _net.CreateNetMessage<MsgMapStrServerHandshake>();
message.Hash = _stringMapHash;
_net.ServerSendMessage(message, channel);
while (_incompleteHandshakes.Contains(channel))
{
await Task.Delay(1);
}
LogSzr.Debug($"Completed handshake with {channel.RemoteEndPoint.Address}.");
return tcs.Task;
}
/// <summary>
@@ -219,6 +216,22 @@ namespace Robust.Shared.Serialization
_net.RegisterNetMessage<MsgMapStrServerHandshake>(nameof(MsgMapStrServerHandshake), HandleServerHandshake);
_net.RegisterNetMessage<MsgMapStrClientHandshake>(nameof(MsgMapStrClientHandshake), HandleClientHandshake);
_net.RegisterNetMessage<MsgMapStrStrings>(nameof(MsgMapStrStrings), HandleStringsMessage);
_net.Disconnect += NetOnDisconnect;
}
private void NetOnDisconnect(object? sender, NetDisconnectedArgs e)
{
// Cancel handshake in-progress if client disconnects mid-handshake.
var channel = e.Channel;
if (_incompleteHandshakes.TryGetValue(channel, out var handshake))
{
var tcs = handshake.Tcs;
LogSzr.Debug($"Cancelling handshake for disconnected client {channel.SessionId}");
tcs.SetCanceled();
}
_incompleteHandshakes.Remove(channel);
}
/// <summary>
@@ -320,23 +333,38 @@ namespace Robust.Shared.Serialization
DebugTools.Assert(_net.IsServer);
DebugTools.Assert(_dict.Locked);
LogSzr.Debug($"Received handshake from {msgMapStr.MsgChannel.RemoteEndPoint.Address}.");
var channel = msgMapStr.MsgChannel;
LogSzr.Debug($"Received handshake from {channel.SessionId}.");
if (!msgMapStr.NeedsStrings)
if (!_incompleteHandshakes.TryGetValue(channel, out var handshake))
{
LogSzr.Debug($"Completing handshake with {msgMapStr.MsgChannel.RemoteEndPoint.Address}.");
_incompleteHandshakes.Remove(msgMapStr.MsgChannel);
channel.Disconnect("MsgMapStrClientHandshake without in-progress handshake.");
return;
}
// TODO: count and limit number of requests to send strings during handshake
if (!msgMapStr.NeedsStrings)
{
LogSzr.Debug($"Completing handshake with {channel.SessionId}.");
handshake.Tcs.SetResult(null);
_incompleteHandshakes.Remove(channel);
return;
}
if (handshake.HasRequestedStrings)
{
channel.Disconnect("Cannot request strings twice");
return;
}
handshake.HasRequestedStrings = true;
var strings = _net.CreateNetMessage<MsgMapStrStrings>();
strings.Package = _mappedStringsPackage;
LogSzr.Debug(
$"Sending {_mappedStringsPackage!.Length} bytes sized mapped strings package to {msgMapStr.MsgChannel.SessionId}.");
$"Sending {_mappedStringsPackage!.Length} bytes sized mapped strings package to {channel.SessionId}.");
_net.ServerSendMessage(strings, msgMapStr.MsgChannel);
_net.ServerSendMessage(strings, channel);
}
/// <summary>
@@ -595,77 +623,6 @@ namespace Robust.Shared.Serialization
mss._dict.ReadMappedString(stream, out value);
}
// TODO: move the below methods to some stream helpers class
#if ROBUST_SERIALIZER_DISABLE_COMPRESSED_UINTS
private static int WriteCompressedUnsignedInt(Stream stream, uint value)
{
WriteUnsignedInt(stream, value);
return 4;
}
private static uint ReadCompressedUnsignedInt(Stream stream, out int byteCount)
{
byteCount = 4;
return ReadUnsignedInt(stream);
}
#else
private static int WriteCompressedUnsignedInt(Stream stream, uint value)
{
var length = 1;
while (value >= 0x80)
{
stream.WriteByte((byte) (0x80 | value));
value >>= 7;
++length;
}
stream.WriteByte((byte) value);
return length;
}
private static uint ReadCompressedUnsignedInt(Stream stream, out int byteCount)
{
byteCount = 0;
var value = 0u;
var shift = 0;
while (stream.CanRead)
{
var current = stream.ReadByte();
++byteCount;
if (current == -1)
{
throw new EndOfStreamException();
}
value |= (0x7Fu & (byte) current) << shift;
shift += 7;
if ((0x80 & current) == 0)
{
return value;
}
}
throw new EndOfStreamException();
}
#endif
[UsedImplicitly]
private static unsafe void WriteUnsignedInt(Stream stream, uint value)
{
var bytes = MemoryMarshal.AsBytes(new ReadOnlySpan<uint>(&value, 1));
stream.Write(bytes);
}
[UsedImplicitly]
private static unsafe uint ReadUnsignedInt(Stream stream)
{
uint value;
var bytes = MemoryMarshal.AsBytes(new Span<uint>(&value, 1));
stream.Read(bytes);
return value;
}
/// <summary>
/// See <see cref="OnClientCompleteHandshake"/>.
/// </summary>
@@ -681,7 +638,7 @@ namespace Robust.Shared.Serialization
(_stringMapHash, _mappedStringsPackage) = _dict.GeneratePackage();
LogSzr.Debug($"Wrote string package in {sw.ElapsedMilliseconds}ms");
LogSzr.Debug($"Wrote string package in {sw.ElapsedMilliseconds}ms size {ByteHelpers.FormatBytes(_mappedStringsPackage.Length)}");
LogSzr.Debug($"String hash is {ConvertToBase64Url(_stringMapHash)}");
}
@@ -698,5 +655,16 @@ namespace Robust.Shared.Serialization
NetworkInitialize();
}
private sealed class InProgressHandshake
{
public readonly TaskCompletionSource<object?> Tcs;
public bool HasRequestedStrings;
public InProgressHandshake(TaskCompletionSource<object?> tcs)
{
Tcs = tcs;
}
}
}
}

View File

@@ -1,6 +1,5 @@
using System;
using System.Threading.Tasks;
using Lidgren.Network;
using Robust.Shared.Interfaces.Network;
namespace Robust.Shared.Serialization

View File

@@ -0,0 +1,37 @@
using System;
namespace Robust.Shared.Utility
{
public static class ByteHelpers
{
public static string FormatKibibytes(long bytes)
{
return $"{bytes / 1024} KiB";
}
public static string FormatBytes(long bytes)
{
double d = bytes;
var i = 0;
for (; i < ByteSuffixes.Length && d >= 1024; i++)
{
d /= 1024;
}
return $"{Math.Round(d, 2)} {ByteSuffixes[i]}";
}
private static readonly string[] ByteSuffixes =
{
"B",
"KiB",
"MiB",
"GiB",
"TiB",
"PiB",
"EiB",
"ZiB",
"YiB"
};
}
}