From 0d52def877340f2c22e72b94e5a94b15ed6b6c2c Mon Sep 17 00:00:00 2001 From: Tyler Young Date: Thu, 11 Jun 2020 22:09:55 -0400 Subject: [PATCH] Have RobustSerializer use a shared string dictionary (#1117) * implements shared string dictionary and handshake from net-code-2 * fix unit test switch to szr sawmill * try to silence some warnings around ZipEntry * rebase and use system zip instead of icsharplib fix rebase artifacts * Update Robust.Shared/Interfaces/GameObjects/IComponentFactory.cs * Update Robust.Shared/Serialization/RobustSerializer.MappedStringSerializer.cs * Update Robust.Shared/Serialization/RobustSerializer.MappedStringSerializer.cs * Apply suggestions from code review * Apply suggestions from code review * Update Robust.Shared/Serialization/RobustSerializer.cs * since no longer gathering from paths, make string splitting more robust * make string gathering ignore strings under 4 chars long make string gathering yet more robust * add limit to size of mapped strings * add more string data to feed into shared string dictionary from YAML files add JSON importer but don't parse RSI metadata yet fix typo that breaks nulls in MappedStringSerializer minor refactoring make string splitting more robust add WriteUnsignedInt / ReadUnsignedInt for validating WriteCompressedUnsignedInt / ReadCompressedUnsignedInt aren't bogus * comment out some log statements * minor refactor, reorder logging add null check due to smart typing NRT checks * Add doc comments, readability improvements to MappedStringSerializer The protocol, handshake, and internal logic are now more documented. The main area that could still be improved is the documentation of how the cache system works, but the code is readable enough for now that it isn't immediately necessary. * add documentation, organization * update some more doc comments * add flows to doc comment for NetworkInitialize * more documentation and organization * more docs * instead of retrieving INetManager by IoC, assign when NetworkInitialize is invoked * "document" the regex * Update Robust.Shared/Network/NetManager.cs * add missing check for LockMappedStrings * Update Robust.Shared/Serialization/RobustSerializer.MappedStringSerializer.cs Co-authored-by: ComicIronic * change to warning instead of throw for unlocked string mapping Co-authored-by: ComicIronic --- Robust.Client/GameController.cs | 2 +- .../GameStates/ServerGameStateManager.cs | 9 +- Robust.Server/Maps/MapLoader.cs | 1 + Robust.Server/Player/PlayerManager.cs | 18 +- Robust.Shared/ContentPack/DirLoader.cs | 33 + Robust.Shared/ContentPack/IContentRoot.cs | 6 + Robust.Shared/ContentPack/IModLoader.cs | 2 + Robust.Shared/ContentPack/PackLoader.cs | 14 + .../ResourceManager.SingleStreamLoader.cs | 5 + Robust.Shared/ContentPack/ResourceManager.cs | 2 + Robust.Shared/GameObjects/ComponentFactory.cs | 5 + .../GameObjects/IComponentFactory.cs | 7 + .../Serialization/IRobustSerializer.cs | 6 + .../Localization/LocalizationManager.cs | 3 + Robust.Shared/Network/NetChannel.cs | 6 + Robust.Shared/Network/NetManager.cs | 29 +- Robust.Shared/Network/StringTable.cs | 3 + Robust.Shared/Prototypes/PrototypeManager.cs | 10 +- Robust.Shared/Robust.Shared.csproj | 15 + .../RobustSerializer.Handshake.cs | 35 + ...ppedStringSerializer.MsgClientHandshake.cs | 54 + ...ppedStringSerializer.MsgServerHandshake.cs | 65 + ...lizer.MappedStringSerializer.MsgStrings.cs | 72 + ...RobustSerializer.MappedStringSerializer.cs | 1295 +++++++++++++++++ .../Serialization/RobustSerializer.cs | 106 +- .../Serialization/YamlObjectSerializer.cs | 2 +- Robust.Shared/Utility/TypeAbbreviation.cs | 3 + .../Shared/GameObjects/EntityState_Tests.cs | 28 +- 28 files changed, 1793 insertions(+), 43 deletions(-) create mode 100644 Robust.Shared/Serialization/RobustSerializer.Handshake.cs create mode 100644 Robust.Shared/Serialization/RobustSerializer.MappedStringSerializer.MsgClientHandshake.cs create mode 100644 Robust.Shared/Serialization/RobustSerializer.MappedStringSerializer.MsgServerHandshake.cs create mode 100644 Robust.Shared/Serialization/RobustSerializer.MappedStringSerializer.MsgStrings.cs create mode 100644 Robust.Shared/Serialization/RobustSerializer.MappedStringSerializer.cs diff --git a/Robust.Client/GameController.cs b/Robust.Client/GameController.cs index a6f1e7c03..ea2ffe094 100644 --- a/Robust.Client/GameController.cs +++ b/Robust.Client/GameController.cs @@ -160,9 +160,9 @@ namespace Robust.Client _modLoader.BroadcastRunLevel(ModRunLevel.PreInit); _modLoader.BroadcastRunLevel(ModRunLevel.Init); - _serializer.Initialize(); _userInterfaceManager.Initialize(); _networkManager.Initialize(false); + _serializer.Initialize(); _inputManager.Initialize(); _console.Initialize(); _prototypeManager.LoadDirectory(new ResourcePath(@"/Prototypes/")); diff --git a/Robust.Server/GameStates/ServerGameStateManager.cs b/Robust.Server/GameStates/ServerGameStateManager.cs index 12941e2aa..c1de29a8c 100644 --- a/Robust.Server/GameStates/ServerGameStateManager.cs +++ b/Robust.Server/GameStates/ServerGameStateManager.cs @@ -114,15 +114,16 @@ namespace Robust.Server.GameStates var oldestAck = GameTick.MaxValue; - foreach (var channel in _networkManager.Channels) + + foreach (var session in _playerManager.GetAllPlayers()) { - var session = _playerManager.GetSessionByChannel(channel); - if (session == null || session.Status != SessionStatus.InGame) + if (session.Status != SessionStatus.InGame) { - // client still joining, maybe iterate over sessions instead? continue; } + var channel = session.ConnectedClient; + if (!_ackedStates.TryGetValue(channel.ConnectionId, out var lastAck)) { DebugTools.Assert("Why does this channel not have an entry?"); diff --git a/Robust.Server/Maps/MapLoader.cs b/Robust.Server/Maps/MapLoader.cs index 1d63dec7d..4afcffc3b 100644 --- a/Robust.Server/Maps/MapLoader.cs +++ b/Robust.Server/Maps/MapLoader.cs @@ -871,6 +871,7 @@ namespace Robust.Server.Maps RootNode = stream.Documents[0].RootNode; GridCount = ((YamlSequenceNode)RootNode["grids"]).Children.Count; + RobustSerializer.MappedStringSerializer.AddStrings(stream, "anonymous map YAML stream"); } } } diff --git a/Robust.Server/Player/PlayerManager.cs b/Robust.Server/Player/PlayerManager.cs index a41e076c8..05775675e 100644 --- a/Robust.Server/Player/PlayerManager.cs +++ b/Robust.Server/Player/PlayerManager.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Threading; +using System.Threading.Tasks; using Prometheus; using Robust.Server.Interfaces; using Robust.Server.Interfaces.Player; @@ -359,18 +360,25 @@ namespace Robust.Server.Player Dirty(); } - private void HandleWelcomeMessageReq(MsgServerInfoReq message) + private async void HandleWelcomeMessageReq(MsgServerInfoReq message) { - var session = GetSessionByChannel(message.MsgChannel); - - var netMsg = message.MsgChannel.CreateNetMessage(); + var channel = message.MsgChannel; + var netMsg = channel.CreateNetMessage(); netMsg.ServerName = _baseServer.ServerName; netMsg.ServerMaxPlayers = _baseServer.MaxPlayers; netMsg.TickRate = _timing.TickRate; + + IPlayerSession session; + while (!TryGetSessionByChannel(channel, out session)) + { + await Task.Delay(10); + if (!channel.IsConnected) return; + } + netMsg.PlayerSessionId = session.SessionId; - message.MsgChannel.SendMessage(netMsg); + channel.SendMessage(netMsg); } private void HandlePlayerListReq(MsgPlayerListReq message) diff --git a/Robust.Shared/ContentPack/DirLoader.cs b/Robust.Shared/ContentPack/DirLoader.cs index 268588e61..a3ce21f07 100644 --- a/Robust.Shared/ContentPack/DirLoader.cs +++ b/Robust.Shared/ContentPack/DirLoader.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.IO; +using System.Reflection; using System.Threading.Tasks; using Robust.Shared.Interfaces.Log; using Robust.Shared.Utility; @@ -125,6 +126,38 @@ namespace Robust.Shared.ContentPack } }); } + public IEnumerable GetRelativeFilePaths() + { + return GetRelativeFilePaths(_directory); + } + + private IEnumerable GetRelativeFilePaths(DirectoryInfo dir) + { + foreach (var file in dir.EnumerateFiles()) + { + if ((file.Attributes & FileAttributes.Hidden) != 0 || file.Name[0] == '.') + { + continue; + } + + var filePath = file.FullName; + var relPath = filePath.Substring(_directory.FullName.Length); + yield return ResourcePath.FromRelativeSystemPath(relPath).ToRootedPath().ToString(); + } + + foreach (var subDir in dir.EnumerateDirectories()) + { + if (((subDir.Attributes & FileAttributes.Hidden) != 0) || subDir.Name[0] == '.') + { + continue; + } + + foreach (var relPath in GetRelativeFilePaths(subDir)) + { + yield return relPath; + } + } + } } } } diff --git a/Robust.Shared/ContentPack/IContentRoot.cs b/Robust.Shared/ContentPack/IContentRoot.cs index dd0b1c786..5d94d4eaf 100644 --- a/Robust.Shared/ContentPack/IContentRoot.cs +++ b/Robust.Shared/ContentPack/IContentRoot.cs @@ -32,6 +32,12 @@ namespace Robust.Shared.ContentPack /// Directory to search inside of. /// Enumeration of all relative file paths of the files found. IEnumerable FindFiles(ResourcePath path); + + /// + /// Recursively returns relative paths to resource files. + /// + /// Enumeration of all relative file paths. + IEnumerable GetRelativeFilePaths(); } } } diff --git a/Robust.Shared/ContentPack/IModLoader.cs b/Robust.Shared/ContentPack/IModLoader.cs index c35350580..e690756ee 100644 --- a/Robust.Shared/ContentPack/IModLoader.cs +++ b/Robust.Shared/ContentPack/IModLoader.cs @@ -1,4 +1,6 @@ +using System.Collections.Generic; using System.IO; +using System.Reflection; using Robust.Shared.Interfaces.Resources; using Robust.Shared.Timing; diff --git a/Robust.Shared/ContentPack/PackLoader.cs b/Robust.Shared/ContentPack/PackLoader.cs index 09239e1c5..4f0e3bf15 100644 --- a/Robust.Shared/ContentPack/PackLoader.cs +++ b/Robust.Shared/ContentPack/PackLoader.cs @@ -90,6 +90,20 @@ namespace Robust.Shared.ContentPack } } } + + public IEnumerable GetRelativeFilePaths() + { + foreach (var entry in _zip.Entries) + { + if (entry.Name == "") + { + // Dir node. + continue; + } + + yield return new ResourcePath(entry.FullName).ToRootedPath().ToString(); + } + } } } } diff --git a/Robust.Shared/ContentPack/ResourceManager.SingleStreamLoader.cs b/Robust.Shared/ContentPack/ResourceManager.SingleStreamLoader.cs index 76d356e51..004133255 100644 --- a/Robust.Shared/ContentPack/ResourceManager.SingleStreamLoader.cs +++ b/Robust.Shared/ContentPack/ResourceManager.SingleStreamLoader.cs @@ -47,6 +47,11 @@ namespace Robust.Shared.ContentPack yield return _resourcePath; } } + + public IEnumerable GetRelativeFilePaths() + { + yield return _resourcePath.ToString(); + } } } } diff --git a/Robust.Shared/ContentPack/ResourceManager.cs b/Robust.Shared/ContentPack/ResourceManager.cs index 2afbe6016..1e79a3969 100644 --- a/Robust.Shared/ContentPack/ResourceManager.cs +++ b/Robust.Shared/ContentPack/ResourceManager.cs @@ -5,9 +5,11 @@ using System.IO; using System.Text.RegularExpressions; using System.Threading; using Robust.Shared.Interfaces.Configuration; +using Robust.Shared.Interfaces.Network; using Robust.Shared.Interfaces.Resources; using Robust.Shared.IoC; using Robust.Shared.Log; +using Robust.Shared.Serialization; using Robust.Shared.Utility; namespace Robust.Shared.ContentPack diff --git a/Robust.Shared/GameObjects/ComponentFactory.cs b/Robust.Shared/GameObjects/ComponentFactory.cs index 7f21fdd83..b90ba3c25 100644 --- a/Robust.Shared/GameObjects/ComponentFactory.cs +++ b/Robust.Shared/GameObjects/ComponentFactory.cs @@ -213,6 +213,11 @@ namespace Robust.Shared.GameObjects return _typeFactory.CreateInstance(GetRegistration(componentName).Type); } + public IComponent GetComponent(uint netId) + { + return _typeFactory.CreateInstance(GetRegistration(netId).Type); + } + public IComponentRegistration GetRegistration(string componentName) { try diff --git a/Robust.Shared/Interfaces/GameObjects/IComponentFactory.cs b/Robust.Shared/Interfaces/GameObjects/IComponentFactory.cs index ac31237ad..3cb521736 100644 --- a/Robust.Shared/Interfaces/GameObjects/IComponentFactory.cs +++ b/Robust.Shared/Interfaces/GameObjects/IComponentFactory.cs @@ -104,6 +104,13 @@ namespace Robust.Shared.Interfaces.GameObjects /// A Component IComponent GetComponent(string componentName); + /// + /// Gets a new component instantiated of the specified network ID. + /// + /// net id of component to make + /// A Component + IComponent GetComponent(uint netId); + /// /// Gets the registration belonging to a component. /// diff --git a/Robust.Shared/Interfaces/Serialization/IRobustSerializer.cs b/Robust.Shared/Interfaces/Serialization/IRobustSerializer.cs index 85ec6905d..26836e14f 100644 --- a/Robust.Shared/Interfaces/Serialization/IRobustSerializer.cs +++ b/Robust.Shared/Interfaces/Serialization/IRobustSerializer.cs @@ -1,5 +1,7 @@ using System; using System.IO; +using System.Threading.Tasks; +using Robust.Shared.Interfaces.Network; namespace Robust.Shared.Interfaces.Serialization { @@ -10,5 +12,9 @@ namespace Robust.Shared.Interfaces.Serialization T Deserialize(Stream stream); object Deserialize(Stream stream); bool CanSerialize(Type type); + + Task Handshake(INetChannel sender); + + event Action ClientHandshakeComplete; } } diff --git a/Robust.Shared/Localization/LocalizationManager.cs b/Robust.Shared/Localization/LocalizationManager.cs index 53bb4ad12..57aec23f6 100644 --- a/Robust.Shared/Localization/LocalizationManager.cs +++ b/Robust.Shared/Localization/LocalizationManager.cs @@ -7,6 +7,7 @@ using NGettext; using Robust.Shared.Interfaces.Resources; using Robust.Shared.IoC; using Robust.Shared.Localization.Macros; +using Robust.Shared.Serialization; using Robust.Shared.Utility; using YamlDotNet.RepresentationModel; @@ -163,6 +164,8 @@ namespace Robust.Shared.Localization { _readEntry(entry, catalog); } + + RobustSerializer.MappedStringSerializer.AddStrings(yamlStream, filePath.ToString()); } private static void _readEntry(YamlMappingNode entry, Catalog catalog) diff --git a/Robust.Shared/Network/NetChannel.cs b/Robust.Shared/Network/NetChannel.cs index bb132ad54..aad9717c4 100644 --- a/Robust.Shared/Network/NetChannel.cs +++ b/Robust.Shared/Network/NetChannel.cs @@ -58,6 +58,12 @@ namespace Robust.Shared.Network /// public void SendMessage(NetMessage message) { + if (_manager.IsClient) + { + _manager.ClientSendMessage(message); + return; + } + _manager.ServerSendMessage(message, this); } diff --git a/Robust.Shared/Network/NetManager.cs b/Robust.Shared/Network/NetManager.cs index ef9bc3b0e..a331c149b 100644 --- a/Robust.Shared/Network/NetManager.cs +++ b/Robust.Shared/Network/NetManager.cs @@ -14,13 +14,14 @@ using Prometheus; using Robust.Shared.Configuration; using Robust.Shared.Interfaces.Configuration; using Robust.Shared.Interfaces.Network; +using Robust.Shared.Interfaces.Serialization; using Robust.Shared.IoC; using Robust.Shared.Log; using Robust.Shared.Utility; namespace Robust.Shared.Network { - /// + /// /// Callback for registered NetMessages. /// /// The message received. @@ -37,6 +38,9 @@ namespace Robust.Shared.Network /// public partial class NetManager : IClientNetManager, IServerNetManager, IDisposable { + + [Dependency] private readonly IRobustSerializer _serializer = default!; + private static readonly Counter SentPacketsMetrics = Metrics.CreateCounter( "robust_net_sent_packets", "Number of packets sent since server startup."); @@ -226,7 +230,15 @@ namespace Robust.Shared.Network _config.RegisterCVar("net.fakelagrand", 0.0f, CVar.CHEAT, _fakeLagRandomChanged); #endif - _strings.Initialize(this, () => { OnConnected(ServerChannel!); }); + _strings.Initialize(this, () => + { + Logger.InfoS("net","Message string table loaded."); + }); + _serializer.ClientHandshakeComplete += () => + { + Logger.InfoS("net","Client completed serializer handshake."); + OnConnected(ServerChannel!); + }; _initialized = true; } @@ -621,7 +633,7 @@ namespace Robust.Shared.Network HandleInitialHandshakeComplete(connection); } - private void HandleInitialHandshakeComplete(NetConnection sender) + private async void HandleInitialHandshakeComplete(NetConnection sender) { var session = _assignedSessions[sender]; @@ -630,6 +642,8 @@ namespace Robust.Shared.Network _strings.SendFullTable(channel); + await _serializer.Handshake(channel); + Logger.InfoS("net", $"{channel.RemoteEndPoint}: Connected"); OnConnected(channel); @@ -666,13 +680,22 @@ namespace Robust.Shared.Network { var peer = msg.SenderConnection.Peer; if (peer.Status == NetPeerStatus.ShutdownRequested) + { + Logger.WarningS("net", $"{msg.SenderConnection.RemoteEndPoint}: Received data message, but shutdown is requested."); return true; + } if (peer.Status == NetPeerStatus.NotRunning) + { + Logger.WarningS("net", $"{msg.SenderConnection.RemoteEndPoint}: Received data message, peer is not running."); return true; + } if (!IsConnected) + { + Logger.WarningS("net", $"{msg.SenderConnection.RemoteEndPoint}: Received data message, but not connected."); return true; + } if (_awaitingData.TryGetValue(msg.SenderConnection, out var info)) { diff --git a/Robust.Shared/Network/StringTable.cs b/Robust.Shared/Network/StringTable.cs index e688fbd4c..5d63010ea 100644 --- a/Robust.Shared/Network/StringTable.cs +++ b/Robust.Shared/Network/StringTable.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using Lidgren.Network; using Robust.Shared.Interfaces.Network; +using Robust.Shared.Log; using Robust.Shared.Utility; namespace Robust.Shared.Network @@ -55,6 +56,7 @@ namespace Robust.Shared.Network { if (_network.IsServer) // Server does not receive entries from clients. return; + Logger.InfoS("net",$"Received message name string table."); foreach (var entry in message.Entries) { @@ -245,6 +247,7 @@ namespace Robust.Shared.Network } + Logger.InfoS("net",$"Sending message name string table to {channel.RemoteEndPoint.Address}."); _network.ServerSendMessage(message, channel); } } diff --git a/Robust.Shared/Prototypes/PrototypeManager.cs b/Robust.Shared/Prototypes/PrototypeManager.cs index be1bd5df1..d901924b5 100644 --- a/Robust.Shared/Prototypes/PrototypeManager.cs +++ b/Robust.Shared/Prototypes/PrototypeManager.cs @@ -10,6 +10,7 @@ using Robust.Shared.Interfaces.Resources; using Robust.Shared.IoC; using Robust.Shared.IoC.Exceptions; using Robust.Shared.Log; +using Robust.Shared.Serialization; using Robust.Shared.Utility; using YamlDotNet.Core; using YamlDotNet.RepresentationModel; @@ -211,7 +212,11 @@ namespace Robust.Shared.Prototypes var yamlStream = new YamlStream(); yamlStream.Load(reader); - return ((YamlStream? yamlStream, ResourcePath?))(yamlStream, filePath); + var result = ((YamlStream? yamlStream, ResourcePath?))(yamlStream, filePath); + + RobustSerializer.MappedStringSerializer.AddStrings(yamlStream, filePath.ToString()); + + return result; } catch (YamlException e) { @@ -255,6 +260,9 @@ namespace Robust.Shared.Prototypes throw new PrototypeLoadException(string.Format("Failed to load prototypes from document#{0}", i), e); } } + + + RobustSerializer.MappedStringSerializer.AddStrings(yaml, "anonymous prototypes YAML stream"); } #endregion IPrototypeManager members diff --git a/Robust.Shared/Robust.Shared.csproj b/Robust.Shared/Robust.Shared.csproj index a5e778640..bf1507b72 100644 --- a/Robust.Shared/Robust.Shared.csproj +++ b/Robust.Shared/Robust.Shared.csproj @@ -35,6 +35,21 @@ + + RobustSerializer.MappedStringSerializer.cs + + + RobustSerializer.MappedStringSerializer.cs + + + RobustSerializer.MappedStringSerializer.cs + + + RobustSerializer.cs + + + RobustSerializer.cs + diff --git a/Robust.Shared/Serialization/RobustSerializer.Handshake.cs b/Robust.Shared/Serialization/RobustSerializer.Handshake.cs new file mode 100644 index 000000000..fdc678f8e --- /dev/null +++ b/Robust.Shared/Serialization/RobustSerializer.Handshake.cs @@ -0,0 +1,35 @@ +using System; +using System.Threading.Tasks; +using Lidgren.Network; +using Robust.Shared.Interfaces.Network; + +namespace Robust.Shared.Serialization +{ + + public partial class RobustSerializer + { + /// + /// Initiates any sequence of handshake extensions that + /// need to occur before the serializer is initialized + /// for a given client. + /// + /// + /// + public Task Handshake(INetChannel channel) + => MappedStringSerializer.Handshake(channel); + + /// + /// An event that occurs once all handshake extensions have + /// completed for the client. + /// + /// Note: This should not occur on the server. + /// + public event Action ClientHandshakeComplete + { + add => MappedStringSerializer.ClientHandshakeComplete += value; + remove => MappedStringSerializer.ClientHandshakeComplete -= value; + } + + } + +} diff --git a/Robust.Shared/Serialization/RobustSerializer.MappedStringSerializer.MsgClientHandshake.cs b/Robust.Shared/Serialization/RobustSerializer.MappedStringSerializer.MsgClientHandshake.cs new file mode 100644 index 000000000..374f0052d --- /dev/null +++ b/Robust.Shared/Serialization/RobustSerializer.MappedStringSerializer.MsgClientHandshake.cs @@ -0,0 +1,54 @@ +using JetBrains.Annotations; +using Lidgren.Network; +using Robust.Shared.Interfaces.Network; +using Robust.Shared.Network; + +namespace Robust.Shared.Serialization +{ + + public partial class RobustSerializer + { + + public partial class MappedStringSerializer + { + /// + /// The client part of the string-exchange handshake, sent after the + /// client receives the mapping hash and after the client receives a + /// strings package. Tells the server if the client needs an updated + /// copy of the mapping. + /// + /// + /// Also sent by the client after a new copy of the string mapping + /// has been received. If successfully loaded, the value of + /// is false, otherwise it will be + /// true. + /// + /// + [UsedImplicitly] + private class MsgClientHandshake : NetMessage + { + + public MsgClientHandshake(INetChannel ch) + : base($"{nameof(RobustSerializer)}.{nameof(MappedStringSerializer)}.{nameof(MsgClientHandshake)}", MsgGroups.Core) + { + } + + /// + /// true if the client needs a new copy of the mapping, + /// false otherwise. + /// + public bool NeedsStrings { get; set; } + + public override void ReadFromBuffer(NetIncomingMessage buffer) + => NeedsStrings = buffer.ReadBoolean(); + + public override void WriteToBuffer(NetOutgoingMessage buffer) + => buffer.Write(NeedsStrings); + + } + + } + + } + +} diff --git a/Robust.Shared/Serialization/RobustSerializer.MappedStringSerializer.MsgServerHandshake.cs b/Robust.Shared/Serialization/RobustSerializer.MappedStringSerializer.MsgServerHandshake.cs new file mode 100644 index 000000000..132adebd5 --- /dev/null +++ b/Robust.Shared/Serialization/RobustSerializer.MappedStringSerializer.MsgServerHandshake.cs @@ -0,0 +1,65 @@ +using System; +using JetBrains.Annotations; +using Lidgren.Network; +using Robust.Shared.Interfaces.Network; +using Robust.Shared.Network; + +namespace Robust.Shared.Serialization +{ + + public partial class RobustSerializer + { + + public partial class MappedStringSerializer + { + + /// + /// The server part of the string-exchange handshake. Sent as the + /// first message in the handshake. Tells the client the hash of + /// the current string mapping, so the client can check if it has + /// a local copy. + /// + /// + [UsedImplicitly] + private class MsgServerHandshake : NetMessage + { + + public MsgServerHandshake(INetChannel ch) + : base($"{nameof(RobustSerializer)}.{nameof(MappedStringSerializer)}.{nameof(MsgServerHandshake)}", MsgGroups.Core) + { + } + + /// + /// The hash of the current string mapping held by the server. + /// + public byte[]? Hash { get; set; } + + public override void ReadFromBuffer(NetIncomingMessage buffer) + { + var len = buffer.ReadVariableInt32(); + if (len > 64) + { + throw new InvalidOperationException("Hash too long."); + } + + Hash = buffer.ReadBytes(len); + } + + public override void WriteToBuffer(NetOutgoingMessage buffer) + { + if (Hash == null) + { + throw new InvalidOperationException("Package has not been specified."); + } + + buffer.WriteVariableInt32(Hash.Length); + buffer.Write(Hash); + } + + } + + } + + } + +} diff --git a/Robust.Shared/Serialization/RobustSerializer.MappedStringSerializer.MsgStrings.cs b/Robust.Shared/Serialization/RobustSerializer.MappedStringSerializer.MsgStrings.cs new file mode 100644 index 000000000..e55ef691e --- /dev/null +++ b/Robust.Shared/Serialization/RobustSerializer.MappedStringSerializer.MsgStrings.cs @@ -0,0 +1,72 @@ +using System; +using System.IO; +using JetBrains.Annotations; +using Lidgren.Network; +using Robust.Shared.Interfaces.Network; +using Robust.Shared.Network; + +namespace Robust.Shared.Serialization +{ + + public partial class RobustSerializer + { + + public partial class MappedStringSerializer + { + + /// + /// The meat of the string-exchange handshake sandwich. Sent by the + /// server after the client requests an updated copy of the mapping. + /// Contains the updated string mapping. + /// + /// + [UsedImplicitly] + private class MsgStrings : NetMessage + { + + public MsgStrings(INetChannel ch) + : base($"{nameof(RobustSerializer)}.{nameof(MappedStringSerializer)}.{nameof(MsgStrings)}", MsgGroups.Core) + { + } + + /// + /// The raw bytes of the string mapping held by the server. + /// + public byte[]? Package { get; set; } + + public override void ReadFromBuffer(NetIncomingMessage buffer) + { + var l = buffer.ReadVariableInt32(); + var success = buffer.ReadBytes(l, out var buf); + if (!success) + { + throw new InvalidDataException("Not all of the bytes were available in the message."); + } + + Package = buf; + } + + public override void WriteToBuffer(NetOutgoingMessage buffer) + { + if (Package == null) + { + throw new InvalidOperationException("Package has not been specified."); + } + + buffer.WriteVariableInt32(Package.Length); + var start = buffer.LengthBytes; + buffer.Write(Package); + var added = buffer.LengthBytes - start; + if (added != Package.Length) + { + throw new InvalidOperationException("Not all of the bytes were written to the message."); + } + } + + } + + } + + } + +} diff --git a/Robust.Shared/Serialization/RobustSerializer.MappedStringSerializer.cs b/Robust.Shared/Serialization/RobustSerializer.MappedStringSerializer.cs new file mode 100644 index 000000000..9c7d1d1e9 --- /dev/null +++ b/Robust.Shared/Serialization/RobustSerializer.MappedStringSerializer.cs @@ -0,0 +1,1295 @@ +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Diagnostics; +using System.IO; +using System.IO.Compression; +using System.Linq; +using System.Reflection; +using System.Reflection.Metadata; +using System.Reflection.Metadata.Ecma335; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Security.Cryptography; +using System.Text; +using System.Text.RegularExpressions; +using System.Threading.Tasks; +using JetBrains.Annotations; +using NetSerializer; +using Newtonsoft.Json.Linq; +using Robust.Shared.ContentPack; +using Robust.Shared.Interfaces.Log; +using Robust.Shared.Interfaces.Network; +using Robust.Shared.IoC; +using Robust.Shared.Log; +using Robust.Shared.Utility; +using YamlDotNet.RepresentationModel; + +namespace Robust.Shared.Serialization +{ + + public partial class RobustSerializer + { + + /// + /// Serializer which manages a mapping of pre-loaded strings to constant + /// values, for message compression. The mapping is shared between the + /// server and client. + /// + /// + /// Strings are long and expensive to send over the wire, and lots of + /// strings involved in messages are sent repeatedly between the server + /// and client - such as filenames, icon states, constant strings, etc. + /// + /// To compress these strings, we use a constant string mapping, decided + /// by the server when it starts up, that associates strings with a + /// fixed value. The mapping is shared with clients when they connect. + /// + /// When sending these strings over the wire, the serializer can then + /// send the constant value instead - and at the other end, the + /// serializer can use the same mapping to recover the original string. + /// + public partial class MappedStringSerializer : IStaticTypeSerializer + { + + private static INetManager? _net; + + private static readonly ISawmill LogSzr = Logger.GetSawmill("szr"); + + private static readonly HashSet IncompleteHandshakes = new HashSet(); + + /// + /// Starts the handshake from the server end of the given channel, + /// sending a . + /// + /// The network channel to perform the handshake over. + /// + /// Locks the string mapping if this is the first time the server is + /// performing the handshake. + /// + /// + /// + public static async Task Handshake(INetChannel channel) + { + var net = channel.NetPeer; + + if (net.IsClient) + { + return; + } + + if (!LockMappedStrings) + { + LockMappedStrings = true; + LogSzr.Info($"Locked in at {_MappedStrings.Count} mapped strings."); + } + + IncompleteHandshakes.Add(channel); + + var message = net.CreateNetMessage(); + message.Hash = MappedStringsHash; + net.ServerSendMessage(message, channel); + + while (IncompleteHandshakes.Contains(channel)) + { + await Task.Delay(1); + } + + LogSzr.Info($"Completed handshake with {channel.RemoteEndPoint.Address}."); + } + + /// + /// Performs the setup so that the serializer can perform the string- + /// exchange protocol. + /// + /// + /// The string-exchange protocol is started by the server when the + /// client first connects. The server sends the client a hash of the + /// string mapping; the client checks that hash against any local + /// caches; and if necessary, the client requests a new copy of the + /// mapping from the server. + /// + /// Uncached flow: + /// Client | Server + /// | <-------------- Hash | + /// | Need Strings ------> | + /// | <----------- Strings | + /// | Dont Need Strings -> | + /// + /// + /// Cached flow: + /// Client | Server + /// | <-------------- Hash | + /// | Dont Need Strings -> | + /// + /// + /// Verification failure flow: + /// Client | Server + /// | <-------------- Hash | + /// | Need Strings ------> | + /// | <----------- Strings | + /// + Hash Failed | + /// | Need Strings ------> | + /// | <----------- Strings | + /// | Dont Need Strings -> | + /// + /// + /// NOTE: Verification failure flow is currently not implemented. + /// + /// + /// The to perform the protocol steps over. + /// + /// + /// + /// + /// + /// + /// + /// + public static void NetworkInitialize(INetManager net) + { + _net = net; + + net.RegisterNetMessage( + $"{nameof(RobustSerializer)}.{nameof(MappedStringSerializer)}.{nameof(MsgServerHandshake)}", + msg => HandleServerHandshake(net, msg)); + + net.RegisterNetMessage( + $"{nameof(RobustSerializer)}.{nameof(MappedStringSerializer)}.{nameof(MsgClientHandshake)}", + msg => HandleClientHandshake(net, msg)); + + net.RegisterNetMessage( + $"{nameof(RobustSerializer)}.{nameof(MappedStringSerializer)}.{nameof(MsgStrings)}", + msg => HandleStringsMessage(net, msg)); + } + + /// + /// Handles the reception, verification of a strings package + /// and subsequent mapping of strings and initiator of + /// receipt response. + /// + /// Uncached flow: + /// Client | Server + /// | <-------------- Hash | + /// | Need Strings ------> | + /// | <----------- Strings | + /// | Dont Need Strings -> | <- you are here on client + /// + /// Verification failure flow: + /// Client | Server + /// | <-------------- Hash | + /// | Need Strings ------> | + /// | <----------- Strings | + /// + Hash Failed | <- you are here on client + /// | Need Strings ------> | + /// | <----------- Strings | + /// | Dont Need Strings -> | <- you are here on client + /// + /// + /// NOTE: Verification failure flow is currently not implemented. + /// + /// + /// Unable to verify strings package by hash. + /// + private static void HandleStringsMessage(INetManager net, MsgStrings msg) + { + if (net.IsServer) + { + LogSzr.Error("Received strings from client."); + return; + } + + LockMappedStrings = false; + ClearStrings(); + DebugTools.Assert(msg.Package != null, "msg.Package != null"); + LoadStrings(new MemoryStream(msg.Package!, false)); + var checkHash = CalculateHash(msg.Package!); + if (!checkHash.SequenceEqual(ServerHash)) + { + // TODO: retry sending MsgClientHandshake with NeedsStrings = false + throw new InvalidOperationException("Unable to verify strings package by hash." + $"\n{ConvertToBase64Url(checkHash)} vs. {ConvertToBase64Url(ServerHash)}"); + } + + _stringMapHash = ServerHash; + LockMappedStrings = true; + + LogSzr.Info($"Locked in at {_MappedStrings.Count} mapped strings."); + + WriteStringCache(); + + // ok we're good now + var channel = msg.MsgChannel; + OnClientCompleteHandshake(net, channel); + } + + /// + /// Interpret a client's handshake, either sending a package + /// of strings or completing the handshake. + /// + /// Uncached flow: + /// Client | Server + /// | <-------------- Hash | + /// | Need Strings ------> | <- you are here on server + /// | <----------- Strings | + /// | Dont Need Strings -> | <- you are here on server + /// + /// + /// Cached flow: + /// Client | Server + /// | <-------------- Hash | + /// | Dont Need Strings -> | <- you are here on server + /// + /// + /// Verification failure flow: + /// Client | Server + /// | <-------------- Hash | + /// | Need Strings ------> | <- you are here on server + /// | <----------- Strings | + /// + Hash Failed | + /// | Need Strings ------> | <- you are here on server + /// | <----------- Strings | + /// | Dont Need Strings -> | + /// + /// + /// NOTE: Verification failure flow is currently not implemented. + /// + /// + private static void HandleClientHandshake(INetManager net, MsgClientHandshake msg) + { + if (net.IsClient) + { + LogSzr.Error("Received client handshake on client."); + return; + } + + LogSzr.Info($"Received handshake from {msg.MsgChannel.RemoteEndPoint.Address}."); + + if (!msg.NeedsStrings) + { + LogSzr.Info($"Completing handshake with {msg.MsgChannel.RemoteEndPoint.Address}."); + IncompleteHandshakes.Remove(msg.MsgChannel); + return; + } + + // TODO: count and limit number of requests to send strings during handshake + + var strings = msg.MsgChannel.NetPeer.CreateNetMessage(); + using (var ms = new MemoryStream()) + { + WriteStringPackage(ms); + ms.Position = 0; + strings.Package = ms.ToArray(); + LogSzr.Info($"Sending {ms.Length} bytes sized mapped strings package to {msg.MsgChannel.RemoteEndPoint.Address}."); + } + + msg.MsgChannel.SendMessage(strings); + } + + /// + /// Interpret a server's handshake, either requesting a package + /// of strings or completing the handshake. + /// + /// Uncached flow: + /// Client | Server + /// | <-------------- Hash | <- you are here on client + /// | Need Strings ------> | + /// | <----------- Strings | + /// | Dont Need Strings -> | + /// + /// + /// Cached flow: + /// Client | Server + /// | <-------------- Hash | <- you are here on client + /// | Dont Need Strings -> | + /// + /// + /// Verification failure flow: + /// Client | Server + /// | <-------------- Hash | <- you are here on client + /// | Need Strings ------> | + /// | <----------- Strings | + /// + Hash Failed | + /// | Need Strings ------> | + /// | <----------- Strings | + /// | Dont Need Strings -> | + /// + /// + /// NOTE: Verification failure flow is currently not implemented. + /// + /// Mapped strings are locked. + /// + private static void HandleServerHandshake(INetManager net, MsgServerHandshake msg) + { + if (net.IsServer) + { + LogSzr.Error("Received server handshake on server."); + return; + } + + ServerHash = msg.Hash; + LockMappedStrings = false; + + if (LockMappedStrings) + { + throw new InvalidOperationException("Mapped strings are locked."); + } + + ClearStrings(); + + var hashStr = ConvertToBase64Url(Convert.ToBase64String(msg.Hash!)); + + LogSzr.Info($"Received server handshake with hash {hashStr}."); + + var fileName = CacheForHash(hashStr); + if (!File.Exists(fileName)) + { + LogSzr.Info($"No string cache for {hashStr}."); + var handshake = net.CreateNetMessage(); + LogSzr.Info("Asking server to send mapped strings."); + handshake.NeedsStrings = true; + msg.MsgChannel.SendMessage(handshake); + } + else + { + LogSzr.Info($"We had a cached string map that matches {hashStr}."); + using var file = File.OpenRead(fileName); + var added = LoadStrings(file); + + _stringMapHash = msg.Hash!; + LogSzr.Info($"Read {added} strings from cache {hashStr}."); + LockMappedStrings = true; + LogSzr.Info($"Locked in at {_MappedStrings.Count} mapped strings."); + // ok we're good now + var channel = msg.MsgChannel; + OnClientCompleteHandshake(net, channel); + } + } + + /// + /// Inform the server that the client has a complete copy of the + /// mapping, and alert other code that the handshake is over. + /// + /// + /// + private static void OnClientCompleteHandshake(INetManager net, INetChannel channel) + { + LogSzr.Info("Letting server know we're good to go."); + var handshake = net.CreateNetMessage(); + handshake.NeedsStrings = false; + channel.SendMessage(handshake); + + if (ClientHandshakeComplete == null) + { + LogSzr.Warning("There's no handler attached to ClientHandshakeComplete."); + } + + ClientHandshakeComplete?.Invoke(); + } + + /// + /// Gets the cache file associated with the given hash. + /// + /// The hash to look up the cache for. + /// + /// The filename where the cache for the given hash would be. The + /// file itself may or may not exist. If it does not exist, no cache + /// was made for the given hash. + /// + private static string CacheForHash(string hashStr) + => PathHelpers.ExecutableRelativeFile($"strings-{hashStr}"); + + /// + /// Saves the string cache to a file based on it's hash. + /// + private static void WriteStringCache() + { + var hashStr = Convert.ToBase64String(MappedStringsHash); + hashStr = ConvertToBase64Url(hashStr); + + var fileName = CacheForHash(hashStr); + using var file = File.OpenWrite(fileName); + WriteStringPackage(file); + + LogSzr.Info($"Wrote string cache {hashStr}."); + } + + private static byte[]? _mappedStringsPackage; + + private static byte[] MappedStringsPackage => LockMappedStrings + ? _mappedStringsPackage ??= WriteStringPackage() + : throw new InvalidOperationException("Mapped strings must be locked."); + + /// + /// Writes strings to a package and converts to an array of bytes. + /// + /// + /// This is invoked by accessing for the first time. + /// + private static byte[] WriteStringPackage() + { + using var ms = new MemoryStream(); + WriteStringPackage(ms); + return ms.ToArray(); + } + + /// + /// Strings longer than this will throw an exception and a better strategy will need to be employed to deal with large strings. + /// + public static int StringPackageMaximumBufferSize = 65536; + + /// + /// Writes a strings package to a stream. + /// + /// A writable stream. + /// Overly long string in strings package. + public static void WriteStringPackage(Stream stream) + { + var buf = new byte[StringPackageMaximumBufferSize]; + var sw = Stopwatch.StartNew(); + var enc = Encoding.UTF8.GetEncoder(); + + using (var zs = new DeflateStream(stream, CompressionLevel.Optimal, true)) + { + var bytesWritten = WriteCompressedUnsignedInt(zs, (uint) MappedStrings.Count); + foreach (var str in MappedStrings) + { + if (str.Length >= StringPackageMaximumBufferSize) + { + throw new NotImplementedException("Overly long string in strings package."); + } + + var l = enc.GetBytes(str, buf, true); + + if (l >= StringPackageMaximumBufferSize) + { + throw new NotImplementedException("Overly long string in strings package."); + } + + bytesWritten += WriteCompressedUnsignedInt(zs, (uint) l); + + zs.Write(buf, 0, l); + + bytesWritten += l; + + enc.Reset(); + } + + zs.Write(BitConverter.GetBytes(bytesWritten)); + zs.Flush(); + } + + LogSzr.Info($"Wrote {MappedStrings.Count} strings to package in {sw.ElapsedMilliseconds}ms."); + } + + /// + /// Loads a strings package from a stream. + /// + /// + /// Uses to extract strings and adds them to the mapping. + /// + /// A readable stream. + /// The number of strings loaded. + /// Mapped strings are locked, will not load. + /// Did not read all bytes in package! + private static int LoadStrings(Stream stream) + { + if (LockMappedStrings) + { + throw new InvalidOperationException("Mapped strings are locked, will not load."); + } + + var started = MappedStrings.Count; + foreach (var str in ReadStringPackage(stream)) + { + _StringMapping[str] = _MappedStrings.Count; + _MappedStrings.Add(str); + } + + if (stream.CanSeek && stream.CanRead) + { + if (stream.Position != stream.Length) + { + throw new InvalidDataException("Did not read all bytes in package!"); + } + } + + var added = MappedStrings.Count - started; + return added; + } + + /// + /// Reads the contents of a strings package. + /// + /// + /// Does not add strings to the current mapping. + /// + /// A readable stream. + /// Strings from within the package. + /// Could not read the full length of string #N. + private static IEnumerable ReadStringPackage(Stream stream) + { + var buf = ArrayPool.Shared.Rent(65536); + var sw = Stopwatch.StartNew(); + using var zs = new DeflateStream(stream, CompressionMode.Decompress); + + var c = ReadCompressedUnsignedInt(zs, out var x); + var bytesRead = x; + for (var i = 0; i < c; ++i) + { + var l = (int) ReadCompressedUnsignedInt(zs, out x); + bytesRead += x; + var y = zs.Read(buf, 0, l); + if (y != l) + { + throw new InvalidDataException($"Could not read the full length of string #{i}."); + } + + bytesRead += y; + var str = Encoding.UTF8.GetString(buf, 0, l); + yield return str; + } + + zs.Read(buf, 0, 4); + var checkBytesRead = BitConverter.ToInt32(buf, 0); + if (checkBytesRead != bytesRead) + { + throw new InvalidDataException("Could not verify package was read correctly."); + } + + LogSzr.Info($"Read package of {c} strings in {sw.ElapsedMilliseconds}ms."); + } + + /// + /// Converts a byte array such as a hash to a Base64 representation that is URL safe. + /// + /// + /// A base64url string form of the byte array. + private static string ConvertToBase64Url(byte[]? data) + => data == null ? "" : ConvertToBase64Url(Convert.ToBase64String(data)); + + /// + /// Converts a a Base64 string to one that is URL safe. + /// + /// A base64url formed string. + private static string ConvertToBase64Url(string b64Str) + { + if (b64Str is null) + { + throw new ArgumentNullException(nameof(b64Str)); + } + + var cut = b64Str[^1] == '=' ? b64Str[^2] == '=' ? 2 : 1 : 0; + b64Str = new StringBuilder(b64Str).Replace('+', '-').Replace('/', '_').ToString(0, b64Str.Length - cut); + return b64Str; + } + + /// + /// Converts a URL-safe Base64 string into a byte array. + /// + /// A base64url formed string. + /// The represented byte array. + public static byte[] ConvertFromBase64Url(string s) + { + var l = s.Length % 3; + var sb = new StringBuilder(s); + sb.Replace('-', '+').Replace('_', '/'); + for (var i = 0; i < l; ++i) + { + sb.Append('='); + } + + s = sb.ToString(); + return Convert.FromBase64String(s); + } + + public static byte[]? ServerHash; + + private static readonly IList _MappedStrings = new List(); + + private static readonly IDictionary _StringMapping = new Dictionary(); + + public static IReadOnlyList MappedStrings => new ReadOnlyCollection(_MappedStrings); + + /// + /// Whether the string mapping is decided, and cannot be changed. + /// + /// + /// + /// While false, strings can be added to the mapping, but + /// it cannot be saved to a cache. + /// + /// + /// While true, the mapping cannot be modified, but can be + /// shared between the server and client and saved to a cache. + /// + /// + public static bool LockMappedStrings { get; set; } + + private static readonly Regex RxSymbolSplitter + = new Regex( + @"(?<=[^\s\W])(?=[A-Z]) # Match for split at start of new capital letter + |(?<=[^0-9\s\W])(?=[0-9]) # Match for split before spans of numbers + |(?<=[A-Za-z0-9])(?=_) # Match for a split before an underscore + |(?=[.\\\/,#$?!@|&*()^`""'`~[\]{}:;\-]) # Match for a split after symbols + |(?<=[.\\\/,#$?!@|&*()^`""'`~[\]{}:;\-]) # Match for a split before symbols too", + RegexOptions.CultureInvariant + | RegexOptions.Compiled + | RegexOptions.IgnorePatternWhitespace + ); + + /// + /// Add a string to the constant mapping. + /// + /// + /// If the string has multiple detectable subcomponents, such as a + /// filepath, it may result in more than one string being added to + /// the mapping. As string parts are commonly sent as subsets or + /// scoped names, this increases the likelyhood of a successful + /// string mapping. + /// + /// + /// true if the string was added to the mapping for the first + /// time, false otherwise. + /// + /// + /// Thrown if the mapping is locked, and strings cannot be added, or + /// if the string is not normalized (). + /// + public static bool AddString(string str) + { + if (LockMappedStrings) + { + if (_net.IsClient) + { + //LogSzr.Info("On client and mapped strings are locked, will not add."); + return false; + } + + throw new InvalidOperationException("Mapped strings are locked, will not add."); + } + + if (string.IsNullOrEmpty(str)) + { + return false; + } + + if (!str.IsNormalized()) + { + throw new InvalidOperationException("Only normalized strings may be added."); + } + + if (_StringMapping.ContainsKey(str)) + { + return false; + } + + if (str.Length >= MaxMappedStringSize) return false; + + if (str.Length <= MinMappedStringSize) return false; + + str = str.Trim(); + + if (str.Length <= MinMappedStringSize) return false; + + str = str.Replace(Environment.NewLine, "\n"); + + if (str.Length <= MinMappedStringSize) return false; + + var symTrimmedStr = str.Trim(TrimmableSymbolChars); + if (symTrimmedStr != str) + { + AddString(symTrimmedStr); + } + + if (str.Contains('/')) + { + var parts = str.Split('/', StringSplitOptions.RemoveEmptyEntries); + for (var i = 0; i < parts.Length; ++i) + { + for (var l = 1; l <= parts.Length - i; ++l) + { + var subStr = string.Join('/', parts.Skip(i).Take(l)); + if (_StringMapping.TryAdd(subStr, _MappedStrings.Count)) + { + _MappedStrings.Add(subStr); + } + + if (!subStr.Contains('.')) + { + continue; + } + + var subParts = subStr.Split('.', StringSplitOptions.RemoveEmptyEntries); + for (var si = 0; si < subParts.Length; ++si) + { + for (var sl = 1; sl <= subParts.Length - si; ++sl) + { + var subSubStr = string.Join('.', subParts.Skip(si).Take(sl)); + // ReSharper disable once InvertIf + if (_StringMapping.TryAdd(subSubStr, _MappedStrings.Count)) + { + _MappedStrings.Add(subSubStr); + } + } + } + } + } + } + else if (str.Contains("_")) + { + foreach (var substr in str.Split("_")) + { + AddString(substr); + } + } + else if (str.Contains(" ")) + { + foreach (var substr in str.Split(" ")) + { + if (substr == str) continue; + + AddString(substr); + } + } + else + { + var parts = RxSymbolSplitter.Split(str); + foreach (var substr in parts) + { + if (substr == str) continue; + + AddString(substr); + } + + for (var si = 0; si < parts.Length; ++si) + { + for (var sl = 1; sl <= parts.Length - si; ++sl) + { + var subSubStr = string.Concat(parts.Skip(si).Take(sl)); + if (_StringMapping.TryAdd(subSubStr, _MappedStrings.Count)) + { + _MappedStrings.Add(subSubStr); + } + } + } + } + + if (_StringMapping.TryAdd(str, _MappedStrings.Count)) + { + _MappedStrings.Add(str); + } + + _stringMapHash = null; + _mappedStringsPackage = null; + return true; + } + + /// + /// Add the constant strings from an to the + /// mapping. + /// + /// The assembly from which to collect constant strings. + /// + /// Thrown if the mapping is locked. + /// + [MethodImpl(MethodImplOptions.Synchronized)] + public static unsafe void AddStrings(Assembly asm) + { + if (LockMappedStrings) + { + if (_net.IsClient) + { + //LogSzr.Info("On client and mapped strings are locked, will not add."); + return; + } + + throw new InvalidOperationException("Mapped strings are locked, will not add ."); + } + + var started = MappedStrings.Count; + var sw = Stopwatch.StartNew(); + if (asm.TryGetRawMetadata(out var blob, out var len)) + { + var reader = new MetadataReader(blob, len); + var usrStrHandle = default(UserStringHandle); + do + { + var userStr = reader.GetUserString(usrStrHandle); + if (userStr != "") + { + AddString(string.Intern(userStr.Normalize())); + } + + usrStrHandle = reader.GetNextHandle(usrStrHandle); + } while (usrStrHandle != default); + + var strHandle = default(StringHandle); + do + { + var str = reader.GetString(strHandle); + if (str != "") + { + AddString(string.Intern(str.Normalize())); + } + + strHandle = reader.GetNextHandle(strHandle); + } while (strHandle != default); + } + + var added = MappedStrings.Count - started; + LogSzr.Info($"Mapping {added} strings from {asm.GetName().Name} took {sw.ElapsedMilliseconds}ms."); + } + + /// + /// Add strings from the given to the mapping. + /// + /// + /// Strings are taken from YAML anchors, tags, and leaf nodes. + /// + /// The YAML to collect strings from. + /// The stream name. Only used for logging. + /// + /// Thrown if the mapping is locked. + /// + [MethodImpl(MethodImplOptions.Synchronized)] + public static void AddStrings(YamlStream yaml, string name) + { + if (LockMappedStrings) + { + if (_net.IsClient) + { + //LogSzr.Info("On client and mapped strings are locked, will not add."); + return; + } + + throw new InvalidOperationException("Mapped strings are locked, will not add."); + } + + var started = MappedStrings.Count; + var sw = Stopwatch.StartNew(); + foreach (var doc in yaml) + { + foreach (var node in doc.AllNodes) + { + var a = node.Anchor; + if (!string.IsNullOrEmpty(a)) + { + AddString(a); + } + + var t = node.Tag; + if (!string.IsNullOrEmpty(t)) + { + AddString(t); + } + + switch (node) + { + case YamlScalarNode scalar: + { + var v = scalar.Value; + if (string.IsNullOrEmpty(v)) + { + continue; + } + + AddString(v); + break; + } + } + } + } + + var added = MappedStrings.Count - started; + LogSzr.Info($"Mapping {added} strings from {name} took {sw.ElapsedMilliseconds}ms."); + } + + /// + /// Add strings from the given to the mapping. + /// + /// + /// Strings are taken from JSON property names and string nodes. + /// + /// The JSON to collect strings from. + /// The stream name. Only used for logging. + /// + /// Thrown if the mapping is locked. + /// + public static void AddStrings(JObject obj, string name) + { + if (LockMappedStrings) + { + if (_net.IsClient) + { + //LogSzr.Info("On client and mapped strings are locked, will not add."); + return; + } + + throw new InvalidOperationException("Mapped strings are locked, will not add."); + } + + var started = MappedStrings.Count; + var sw = Stopwatch.StartNew(); + foreach (var node in obj.DescendantsAndSelf()) + { + switch (node) + { + case JValue value: + { + if (value.Type != JTokenType.String) + { + continue; + } + + var v = value.Value?.ToString(); + if (string.IsNullOrEmpty(v)) + { + continue; + } + + AddString(v); + break; + } + case JProperty prop: + { + var propName = prop.Name; + if (string.IsNullOrEmpty(propName)) + { + continue; + } + + AddString(propName); + break; + } + } + } + + var added = MappedStrings.Count - started; + LogSzr.Info($"Mapping {added} strings from {name} took {sw.ElapsedMilliseconds}ms."); + } + + /// + /// Remove all strings from the mapping, completely resetting it. + /// + /// + /// Thrown if the mapping is locked. + /// + public static void ClearStrings() + { + if (LockMappedStrings) + { + throw new InvalidOperationException("Mapped strings are locked, will not clear."); + } + + _MappedStrings.Clear(); + _StringMapping.Clear(); + _stringMapHash = null; + } + + /// + /// Add strings from the given enumeration to the mapping. + /// + /// The strings to add. + /// The source provider of the strings to be logged. + /// + /// Thrown if the mapping is locked. + /// + [MethodImpl(MethodImplOptions.Synchronized)] + public static void AddStrings(IEnumerable strings, string providerName) + { + if (LockMappedStrings) + { + if (_net.IsClient) + { + //LogSzr.Info("On client and mapped strings are locked, will not add."); + return; + } + + throw new InvalidOperationException("Mapped strings are locked, will not add."); + } + + var started = MappedStrings.Count; + foreach (var str in strings) + { + AddString(str); + } + + var added = MappedStrings.Count - started; + LogSzr.Info($"Mapping {added} strings from {providerName}."); + } + + private static byte[]? _stringMapHash; + + /// + /// The hash of the string mapping. + /// + /// + /// Thrown if the mapping is not locked. + /// + public static byte[] MappedStringsHash => _stringMapHash ??= CalculateMappedStringsHash(); + + private static byte[] CalculateMappedStringsHash() + { + if (!LockMappedStrings) + { + throw new InvalidOperationException("String table should be locked before attempting to retrieve hash."); + } + + var sw = Stopwatch.StartNew(); + + var hash = CalculateHash(MappedStringsPackage); + + LogSzr.Info($"Hashing {MappedStrings.Count} strings took {sw.ElapsedMilliseconds}ms."); + LogSzr.Info($"Size: {MappedStringsPackage.Length} bytes, Hash: {ConvertToBase64Url(hash)}"); + return hash; + } + + /// + /// Creates a SHA512 hash of the given array of bytes. + /// + /// An array of bytes to be hashed. + /// A 512-bit (64-byte) hash result as an array of bytes. + /// + private static byte[] CalculateHash(byte[] data) + { + if (data is null) + { + throw new ArgumentNullException(nameof(data)); + } + + using var hasher = SHA512.Create(); + var hash = hasher.ComputeHash(data); + return hash; + } + + /// + /// Implements . + /// Specifies that this implementation handles strings. + /// + public bool Handles(Type type) => type == typeof(string); + + /// + /// Implements . + /// + public IEnumerable GetSubtypes(Type type) => Type.EmptyTypes; + + /// + /// Implements . + /// + /// + public MethodInfo GetStaticWriter(Type type) => WriteMappedStringMethodInfo; + + /// + /// Implements . + /// + /// + public MethodInfo GetStaticReader(Type type) => ReadMappedStringMethodInfo; + + private delegate void WriteStringDelegate(Stream stream, string? value); + + private delegate void ReadStringDelegate(Stream stream, out string? value); + + private static readonly MethodInfo WriteMappedStringMethodInfo + = ((WriteStringDelegate) WriteMappedString).Method; + + private static readonly MethodInfo ReadMappedStringMethodInfo + = ((ReadStringDelegate) ReadMappedString).Method; + + private static readonly char[] TrimmableSymbolChars = + { + '.', '\\', '/', ',', '#', '$', '?', '!', '@', '|', '&', '*', '(', ')', '^', '`', '"', '\'', '`', '~', '[', ']', '{', '}', ':', ';', '-' + }; + + /// + /// The shortest a string can be in order to be inserted in the mapping. + /// + /// + /// Strings below a certain length aren't worth compressing. + /// + private const int MinMappedStringSize = 3; + + /// + /// The longest a string can be in order to be inserted in the mapping. + /// + private const int MaxMappedStringSize = 420; + + /// + /// The special value corresponding to a null string in the + /// encoding. + /// + private const int MappedNull = 0; + + /// + /// The special value corresponding to a string which was not mapped. + /// This is followed by the bytes of the unmapped string. + /// + private const int UnmappedString = 1; + + /// + /// The first non-special value, used for encoding mapped strings. + /// + /// + /// Since previous values are taken by and + /// , this value is used to encode + /// mapped strings at an offset - in the encoding, a value + /// >= FirstMappedIndexStart represents the string with + /// mapping of that value - FirstMappedIndexStart. + /// + private const int FirstMappedIndexStart = 2; + + /// + /// Write the encoding of the given string to the stream. + /// + /// The stream to write to. + /// The (possibly null) string to write. + public static void WriteMappedString(Stream stream, string? value) + { + if (!LockMappedStrings) + { + LogSzr.Warning("Performing unlocked string mapping."); + } + + if (value == null) + { + WriteCompressedUnsignedInt(stream, MappedNull); + return; + } + + if (_StringMapping.TryGetValue(value, out var mapping)) + { +#if DEBUG + if (mapping >= _MappedStrings.Count || mapping < 0) + { + throw new InvalidOperationException("A string mapping outside of the mapped string table was encountered."); + } +#endif + WriteCompressedUnsignedInt(stream, (uint) mapping + FirstMappedIndexStart); + //Logger.DebugS("szr", $"Encoded mapped string: {value}"); + return; + } + + // indicate not mapped + WriteCompressedUnsignedInt(stream, UnmappedString); + var buf = Encoding.UTF8.GetBytes(value); + //Logger.DebugS("szr", $"Encoded unmapped string: {value}"); + WriteCompressedUnsignedInt(stream, (uint) buf.Length); + stream.Write(buf); + } + + /// + /// Try to read a string from the given stream.. + /// + /// The stream to read from. + /// The (possibly null) string read. + /// + /// Thrown if the mapping is not locked. + /// + public static void ReadMappedString(Stream stream, out string? value) + { + if (!LockMappedStrings) + { + throw new InvalidOperationException("Not performing unlocked string mapping."); + } + + var mapIndex = ReadCompressedUnsignedInt(stream, out _); + if (mapIndex == MappedNull) + { + value = null; + return; + } + + if (mapIndex == UnmappedString) + { + // not mapped + var length = ReadCompressedUnsignedInt(stream, out _); + var buf = new byte[length]; + stream.Read(buf); + value = Encoding.UTF8.GetString(buf); + //Logger.DebugS("szr", $"Decoded unmapped string: {value}"); + return; + } + + value = _MappedStrings[(int) mapIndex - FirstMappedIndexStart]; + //Logger.DebugS("szr", $"Decoded mapped string: {value}"); + } + +#if ROBUST_SERIALIZER_DISABLE_COMPRESSED_UINTS + public static int WriteCompressedUnsignedInt(Stream stream, uint value) + { + WriteUnsignedInt(stream, value); + return 4; + } + + public static uint ReadCompressedUnsignedInt(Stream stream, out int byteCount) + { + byteCount = 4; + return ReadUnsignedInt(stream); + } +#else + public static int WriteCompressedUnsignedInt(Stream stream, uint value) + { + var length = 1; + while (value >= 0x80) + { + stream.WriteByte((byte) (0x80 | value)); + value >>= 7; + ++length; + } + + stream.WriteByte((byte) value); + return length; + } + + public static uint ReadCompressedUnsignedInt(Stream stream, out int byteCount) + { + byteCount = 0; + var value = 0u; + var shift = 0; + while (stream.CanRead) + { + var current = stream.ReadByte(); + ++byteCount; + if (current == -1) + { + throw new EndOfStreamException(); + } + + value |= (0x7Fu & (byte) current) << shift; + shift += 7; + if ((0x80 & current) == 0) + { + return value; + } + } + + throw new EndOfStreamException(); + } +#endif + + [UsedImplicitly] + public static unsafe void WriteUnsignedInt(Stream stream, uint value) + { + var bytes = MemoryMarshal.AsBytes(new ReadOnlySpan(&value, 1)); + stream.Write(bytes); + } + + [UsedImplicitly] + public static unsafe uint ReadUnsignedInt(Stream stream) + { + uint value; + var bytes = MemoryMarshal.AsBytes(new Span(&value, 1)); + stream.Read(bytes); + return value; + } + + /// + /// See . + /// + public static event Action? ClientHandshakeComplete; + + } + + } + +} diff --git a/Robust.Shared/Serialization/RobustSerializer.cs b/Robust.Shared/Serialization/RobustSerializer.cs index abc1e9c89..fffacd939 100644 --- a/Robust.Shared/Serialization/RobustSerializer.cs +++ b/Robust.Shared/Serialization/RobustSerializer.cs @@ -6,20 +6,27 @@ using System; using System.Collections.Generic; using System.IO; using System.Linq; +using System.Runtime.Loader; +using Robust.Shared.Interfaces.Network; namespace Robust.Shared.Serialization { - public class RobustSerializer : IRobustSerializer + public partial class RobustSerializer : IRobustSerializer { - [Dependency] private readonly IReflectionManager reflectionManager = default!; - private Serializer Serializer = default!; + [Dependency] private readonly IReflectionManager _reflectionManager = default!; + [Dependency] private readonly INetManager _netManager = default!; - private HashSet SerializableTypes = default!; -#region Statistics + private Serializer _serializer = default!; + + private HashSet _serializableTypes = default!; + + #region Statistics + public static long LargestObjectSerializedBytes { get; private set; } + public static Type? LargestObjectSerializedType { get; private set; } public static long BytesSerialized { get; private set; } @@ -27,74 +34,119 @@ namespace Robust.Shared.Serialization public static long ObjectsSerialized { get; private set; } public static long LargestObjectDeserializedBytes { get; private set; } + public static Type? LargestObjectDeserializedType { get; private set; } public static long BytesDeserialized { get; private set; } public static long ObjectsDeserialized { get; private set; } -#endregion + + #endregion public void Initialize() { - var types = reflectionManager.FindTypesWithAttribute().ToList(); -#if DEBUG + var mappedStringSerializer = new MappedStringSerializer(); + var types = _reflectionManager.FindTypesWithAttribute().ToList(); +#if !FULL_RELEASE + // confirm only shared types are marked for serialization, no client & server only types foreach (var type in types) { - if (type.Assembly.FullName!.Contains("Server") || type.Assembly.FullName.Contains("Client")) + if (type.Assembly.FullName!.Contains("Server")) { - throw new InvalidOperationException($"Type {type} is server/client specific but has a NetSerializableAttribute!"); + throw new InvalidOperationException($"Type {type} is server specific but has a NetSerializableAttribute!"); + } + + if (type.Assembly.FullName.Contains("Client")) + { + throw new InvalidOperationException($"Type {type} is client specific but has a NetSerializableAttribute!"); } } #endif - var settings = new Settings(); - Serializer = new Serializer(types, settings); - SerializableTypes = new HashSet(Serializer.GetTypeMap().Keys); + var settings = new Settings + { + CustomTypeSerializers = new ITypeSerializer[] {mappedStringSerializer} + }; + _serializer = new Serializer(types, settings); + _serializableTypes = new HashSet(_serializer.GetTypeMap().Keys); + + if (_netManager.IsClient) + { + MappedStringSerializer.LockMappedStrings = true; + } + else + { + var defaultAssemblies = AssemblyLoadContext.Default.Assemblies; + var gameAssemblies = _reflectionManager.Assemblies; + var robustShared = defaultAssemblies + .First(a => a.GetName().Name == "Robust.Shared"); + MappedStringSerializer.AddStrings(robustShared); + + // TODO: Need to add a GetSharedAssemblies method to the reflection manager + + var contentShared = gameAssemblies + .FirstOrDefault(a => a.GetName().Name == "Content.Shared"); + if (contentShared != null) + { + MappedStringSerializer.AddStrings(contentShared); + } + + // TODO: Need to add a GetServerAssemblies method to the reflection manager + + var contentServer = gameAssemblies + .FirstOrDefault(a => a.GetName().Name == "Content.Server"); + if (contentServer != null) + { + MappedStringSerializer.AddStrings(contentServer); + } + } + + MappedStringSerializer.NetworkInitialize(_netManager); } public void Serialize(Stream stream, object toSerialize) { var start = stream.Position; - Serializer.Serialize(stream, toSerialize); + _serializer.Serialize(stream, toSerialize); var end = stream.Position; var byteCount = end - start; BytesSerialized += byteCount; ++ObjectsSerialized; - if (byteCount > LargestObjectSerializedBytes) + if (byteCount <= LargestObjectSerializedBytes) { - LargestObjectSerializedBytes = byteCount; - LargestObjectSerializedType = toSerialize.GetType(); + return; } + + LargestObjectSerializedBytes = byteCount; + LargestObjectSerializedType = toSerialize.GetType(); } public T Deserialize(Stream stream) - { - return (T) Deserialize(stream); - } + => (T) Deserialize(stream); public object Deserialize(Stream stream) { var start = stream.Position; - var result = Serializer.Deserialize(stream); + var result = _serializer.Deserialize(stream); var end = stream.Position; var byteCount = end - start; BytesDeserialized += byteCount; ++ObjectsDeserialized; - if (byteCount > LargestObjectDeserializedBytes) + if (byteCount <= LargestObjectDeserializedBytes) { - LargestObjectDeserializedBytes = byteCount; - LargestObjectDeserializedType = result.GetType(); + return result; } + LargestObjectDeserializedBytes = byteCount; + LargestObjectDeserializedType = result.GetType(); + return result; } public bool CanSerialize(Type type) - { - return SerializableTypes.Contains(type); - } + => _serializableTypes.Contains(type); } diff --git a/Robust.Shared/Serialization/YamlObjectSerializer.cs b/Robust.Shared/Serialization/YamlObjectSerializer.cs index 38adf1f00..895e761ff 100644 --- a/Robust.Shared/Serialization/YamlObjectSerializer.cs +++ b/Robust.Shared/Serialization/YamlObjectSerializer.cs @@ -511,7 +511,7 @@ namespace Robust.Shared.Serialization } else { - throw new YamlException($"Malformed type tag."); + throw new YamlException("Malformed type tag."); } } diff --git a/Robust.Shared/Utility/TypeAbbreviation.cs b/Robust.Shared/Utility/TypeAbbreviation.cs index 6be644020..719c4765d 100644 --- a/Robust.Shared/Utility/TypeAbbreviation.cs +++ b/Robust.Shared/Utility/TypeAbbreviation.cs @@ -1,6 +1,7 @@ using System; using System.IO; using System.Text; +using Robust.Shared.Serialization; using YamlDotNet.RepresentationModel; namespace Robust.Shared.Utility @@ -27,6 +28,8 @@ namespace Robust.Shared.Utility var document = yamlStream.Documents[0]; _abbreviations = ParseAbbreviations((YamlSequenceNode) document.RootNode); + + RobustSerializer.MappedStringSerializer.AddStrings(yamlStream, "(embedded) Robust.Shared.Utility.TypeAbbreviations.yaml"); } /// diff --git a/Robust.UnitTesting/Shared/GameObjects/EntityState_Tests.cs b/Robust.UnitTesting/Shared/GameObjects/EntityState_Tests.cs index 48972bd7a..3630009ed 100644 --- a/Robust.UnitTesting/Shared/GameObjects/EntityState_Tests.cs +++ b/Robust.UnitTesting/Shared/GameObjects/EntityState_Tests.cs @@ -3,11 +3,19 @@ using System.Collections.Generic; using System.IO; using NUnit.Framework; using Robust.Server.Reflection; +using Robust.Shared.Configuration; using Robust.Shared.ContentPack; using Robust.Shared.GameObjects; +using Robust.Shared.GameObjects.Components.Map; +using Robust.Shared.Interfaces.Configuration; +using Robust.Shared.Interfaces.Log; +using Robust.Shared.Interfaces.Network; using Robust.Shared.Interfaces.Reflection; using Robust.Shared.Interfaces.Serialization; using Robust.Shared.IoC; +using Robust.Shared.Log; +using Robust.Shared.Map; +using Robust.Shared.Network; using Robust.Shared.Serialization; namespace Robust.UnitTesting.Shared.GameObjects @@ -23,24 +31,42 @@ namespace Robust.UnitTesting.Shared.GameObjects public void ComponentChangedSerialized() { var container = new DependencyCollection(); + container.Register(); + container.Register(); + container.Register(); container.Register(); container.Register(); container.BuildGraph(); container.Resolve().LoadAssemblies(AppDomain.CurrentDomain.GetAssemblyByName("Robust.Shared")); + IoCManager.InitThread(container); + + container.Resolve().Initialize(true); + var serializer = container.Resolve(); serializer.Initialize(); byte[] array; using(var stream = new MemoryStream()) { - var payload = new EntityState(new EntityUid(512), Array.Empty(), Array.Empty()); + var payload = new EntityState( + new EntityUid(512), + new [] + { + new ComponentChanged(false, NetIDs.MAP_GRID, nameof(MapGridComponent)), + }, + new [] + { + new MapGridComponentState(new GridId(0)), + }); serializer.Serialize(stream, payload); array = stream.ToArray(); } + IoCManager.Clear(); + Assert.Pass($"Size in Bytes: {array.Length.ToString()}"); } }