diff --git a/Robust.Client/GameController.cs b/Robust.Client/GameController.cs index a6f1e7c03..ea2ffe094 100644 --- a/Robust.Client/GameController.cs +++ b/Robust.Client/GameController.cs @@ -160,9 +160,9 @@ namespace Robust.Client _modLoader.BroadcastRunLevel(ModRunLevel.PreInit); _modLoader.BroadcastRunLevel(ModRunLevel.Init); - _serializer.Initialize(); _userInterfaceManager.Initialize(); _networkManager.Initialize(false); + _serializer.Initialize(); _inputManager.Initialize(); _console.Initialize(); _prototypeManager.LoadDirectory(new ResourcePath(@"/Prototypes/")); diff --git a/Robust.Server/GameStates/ServerGameStateManager.cs b/Robust.Server/GameStates/ServerGameStateManager.cs index 12941e2aa..c1de29a8c 100644 --- a/Robust.Server/GameStates/ServerGameStateManager.cs +++ b/Robust.Server/GameStates/ServerGameStateManager.cs @@ -114,15 +114,16 @@ namespace Robust.Server.GameStates var oldestAck = GameTick.MaxValue; - foreach (var channel in _networkManager.Channels) + + foreach (var session in _playerManager.GetAllPlayers()) { - var session = _playerManager.GetSessionByChannel(channel); - if (session == null || session.Status != SessionStatus.InGame) + if (session.Status != SessionStatus.InGame) { - // client still joining, maybe iterate over sessions instead? continue; } + var channel = session.ConnectedClient; + if (!_ackedStates.TryGetValue(channel.ConnectionId, out var lastAck)) { DebugTools.Assert("Why does this channel not have an entry?"); diff --git a/Robust.Server/Maps/MapLoader.cs b/Robust.Server/Maps/MapLoader.cs index 1d63dec7d..4afcffc3b 100644 --- a/Robust.Server/Maps/MapLoader.cs +++ b/Robust.Server/Maps/MapLoader.cs @@ -871,6 +871,7 @@ namespace Robust.Server.Maps RootNode = stream.Documents[0].RootNode; GridCount = ((YamlSequenceNode)RootNode["grids"]).Children.Count; + RobustSerializer.MappedStringSerializer.AddStrings(stream, "anonymous map YAML stream"); } } } diff --git a/Robust.Server/Player/PlayerManager.cs b/Robust.Server/Player/PlayerManager.cs index a41e076c8..05775675e 100644 --- a/Robust.Server/Player/PlayerManager.cs +++ b/Robust.Server/Player/PlayerManager.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Threading; +using System.Threading.Tasks; using Prometheus; using Robust.Server.Interfaces; using Robust.Server.Interfaces.Player; @@ -359,18 +360,25 @@ namespace Robust.Server.Player Dirty(); } - private void HandleWelcomeMessageReq(MsgServerInfoReq message) + private async void HandleWelcomeMessageReq(MsgServerInfoReq message) { - var session = GetSessionByChannel(message.MsgChannel); - - var netMsg = message.MsgChannel.CreateNetMessage(); + var channel = message.MsgChannel; + var netMsg = channel.CreateNetMessage(); netMsg.ServerName = _baseServer.ServerName; netMsg.ServerMaxPlayers = _baseServer.MaxPlayers; netMsg.TickRate = _timing.TickRate; + + IPlayerSession session; + while (!TryGetSessionByChannel(channel, out session)) + { + await Task.Delay(10); + if (!channel.IsConnected) return; + } + netMsg.PlayerSessionId = session.SessionId; - message.MsgChannel.SendMessage(netMsg); + channel.SendMessage(netMsg); } private void HandlePlayerListReq(MsgPlayerListReq message) diff --git a/Robust.Shared/ContentPack/DirLoader.cs b/Robust.Shared/ContentPack/DirLoader.cs index 268588e61..a3ce21f07 100644 --- a/Robust.Shared/ContentPack/DirLoader.cs +++ b/Robust.Shared/ContentPack/DirLoader.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.IO; +using System.Reflection; using System.Threading.Tasks; using Robust.Shared.Interfaces.Log; using Robust.Shared.Utility; @@ -125,6 +126,38 @@ namespace Robust.Shared.ContentPack } }); } + public IEnumerable GetRelativeFilePaths() + { + return GetRelativeFilePaths(_directory); + } + + private IEnumerable GetRelativeFilePaths(DirectoryInfo dir) + { + foreach (var file in dir.EnumerateFiles()) + { + if ((file.Attributes & FileAttributes.Hidden) != 0 || file.Name[0] == '.') + { + continue; + } + + var filePath = file.FullName; + var relPath = filePath.Substring(_directory.FullName.Length); + yield return ResourcePath.FromRelativeSystemPath(relPath).ToRootedPath().ToString(); + } + + foreach (var subDir in dir.EnumerateDirectories()) + { + if (((subDir.Attributes & FileAttributes.Hidden) != 0) || subDir.Name[0] == '.') + { + continue; + } + + foreach (var relPath in GetRelativeFilePaths(subDir)) + { + yield return relPath; + } + } + } } } } diff --git a/Robust.Shared/ContentPack/IContentRoot.cs b/Robust.Shared/ContentPack/IContentRoot.cs index dd0b1c786..5d94d4eaf 100644 --- a/Robust.Shared/ContentPack/IContentRoot.cs +++ b/Robust.Shared/ContentPack/IContentRoot.cs @@ -32,6 +32,12 @@ namespace Robust.Shared.ContentPack /// Directory to search inside of. /// Enumeration of all relative file paths of the files found. IEnumerable FindFiles(ResourcePath path); + + /// + /// Recursively returns relative paths to resource files. + /// + /// Enumeration of all relative file paths. + IEnumerable GetRelativeFilePaths(); } } } diff --git a/Robust.Shared/ContentPack/IModLoader.cs b/Robust.Shared/ContentPack/IModLoader.cs index c35350580..e690756ee 100644 --- a/Robust.Shared/ContentPack/IModLoader.cs +++ b/Robust.Shared/ContentPack/IModLoader.cs @@ -1,4 +1,6 @@ +using System.Collections.Generic; using System.IO; +using System.Reflection; using Robust.Shared.Interfaces.Resources; using Robust.Shared.Timing; diff --git a/Robust.Shared/ContentPack/PackLoader.cs b/Robust.Shared/ContentPack/PackLoader.cs index 09239e1c5..4f0e3bf15 100644 --- a/Robust.Shared/ContentPack/PackLoader.cs +++ b/Robust.Shared/ContentPack/PackLoader.cs @@ -90,6 +90,20 @@ namespace Robust.Shared.ContentPack } } } + + public IEnumerable GetRelativeFilePaths() + { + foreach (var entry in _zip.Entries) + { + if (entry.Name == "") + { + // Dir node. + continue; + } + + yield return new ResourcePath(entry.FullName).ToRootedPath().ToString(); + } + } } } } diff --git a/Robust.Shared/ContentPack/ResourceManager.SingleStreamLoader.cs b/Robust.Shared/ContentPack/ResourceManager.SingleStreamLoader.cs index 76d356e51..004133255 100644 --- a/Robust.Shared/ContentPack/ResourceManager.SingleStreamLoader.cs +++ b/Robust.Shared/ContentPack/ResourceManager.SingleStreamLoader.cs @@ -47,6 +47,11 @@ namespace Robust.Shared.ContentPack yield return _resourcePath; } } + + public IEnumerable GetRelativeFilePaths() + { + yield return _resourcePath.ToString(); + } } } } diff --git a/Robust.Shared/ContentPack/ResourceManager.cs b/Robust.Shared/ContentPack/ResourceManager.cs index 2afbe6016..1e79a3969 100644 --- a/Robust.Shared/ContentPack/ResourceManager.cs +++ b/Robust.Shared/ContentPack/ResourceManager.cs @@ -5,9 +5,11 @@ using System.IO; using System.Text.RegularExpressions; using System.Threading; using Robust.Shared.Interfaces.Configuration; +using Robust.Shared.Interfaces.Network; using Robust.Shared.Interfaces.Resources; using Robust.Shared.IoC; using Robust.Shared.Log; +using Robust.Shared.Serialization; using Robust.Shared.Utility; namespace Robust.Shared.ContentPack diff --git a/Robust.Shared/GameObjects/ComponentFactory.cs b/Robust.Shared/GameObjects/ComponentFactory.cs index 7f21fdd83..b90ba3c25 100644 --- a/Robust.Shared/GameObjects/ComponentFactory.cs +++ b/Robust.Shared/GameObjects/ComponentFactory.cs @@ -213,6 +213,11 @@ namespace Robust.Shared.GameObjects return _typeFactory.CreateInstance(GetRegistration(componentName).Type); } + public IComponent GetComponent(uint netId) + { + return _typeFactory.CreateInstance(GetRegistration(netId).Type); + } + public IComponentRegistration GetRegistration(string componentName) { try diff --git a/Robust.Shared/Interfaces/GameObjects/IComponentFactory.cs b/Robust.Shared/Interfaces/GameObjects/IComponentFactory.cs index ac31237ad..3cb521736 100644 --- a/Robust.Shared/Interfaces/GameObjects/IComponentFactory.cs +++ b/Robust.Shared/Interfaces/GameObjects/IComponentFactory.cs @@ -104,6 +104,13 @@ namespace Robust.Shared.Interfaces.GameObjects /// A Component IComponent GetComponent(string componentName); + /// + /// Gets a new component instantiated of the specified network ID. + /// + /// net id of component to make + /// A Component + IComponent GetComponent(uint netId); + /// /// Gets the registration belonging to a component. /// diff --git a/Robust.Shared/Interfaces/Serialization/IRobustSerializer.cs b/Robust.Shared/Interfaces/Serialization/IRobustSerializer.cs index 85ec6905d..26836e14f 100644 --- a/Robust.Shared/Interfaces/Serialization/IRobustSerializer.cs +++ b/Robust.Shared/Interfaces/Serialization/IRobustSerializer.cs @@ -1,5 +1,7 @@ using System; using System.IO; +using System.Threading.Tasks; +using Robust.Shared.Interfaces.Network; namespace Robust.Shared.Interfaces.Serialization { @@ -10,5 +12,9 @@ namespace Robust.Shared.Interfaces.Serialization T Deserialize(Stream stream); object Deserialize(Stream stream); bool CanSerialize(Type type); + + Task Handshake(INetChannel sender); + + event Action ClientHandshakeComplete; } } diff --git a/Robust.Shared/Localization/LocalizationManager.cs b/Robust.Shared/Localization/LocalizationManager.cs index 53bb4ad12..57aec23f6 100644 --- a/Robust.Shared/Localization/LocalizationManager.cs +++ b/Robust.Shared/Localization/LocalizationManager.cs @@ -7,6 +7,7 @@ using NGettext; using Robust.Shared.Interfaces.Resources; using Robust.Shared.IoC; using Robust.Shared.Localization.Macros; +using Robust.Shared.Serialization; using Robust.Shared.Utility; using YamlDotNet.RepresentationModel; @@ -163,6 +164,8 @@ namespace Robust.Shared.Localization { _readEntry(entry, catalog); } + + RobustSerializer.MappedStringSerializer.AddStrings(yamlStream, filePath.ToString()); } private static void _readEntry(YamlMappingNode entry, Catalog catalog) diff --git a/Robust.Shared/Network/NetChannel.cs b/Robust.Shared/Network/NetChannel.cs index bb132ad54..aad9717c4 100644 --- a/Robust.Shared/Network/NetChannel.cs +++ b/Robust.Shared/Network/NetChannel.cs @@ -58,6 +58,12 @@ namespace Robust.Shared.Network /// public void SendMessage(NetMessage message) { + if (_manager.IsClient) + { + _manager.ClientSendMessage(message); + return; + } + _manager.ServerSendMessage(message, this); } diff --git a/Robust.Shared/Network/NetManager.cs b/Robust.Shared/Network/NetManager.cs index ef9bc3b0e..a331c149b 100644 --- a/Robust.Shared/Network/NetManager.cs +++ b/Robust.Shared/Network/NetManager.cs @@ -14,13 +14,14 @@ using Prometheus; using Robust.Shared.Configuration; using Robust.Shared.Interfaces.Configuration; using Robust.Shared.Interfaces.Network; +using Robust.Shared.Interfaces.Serialization; using Robust.Shared.IoC; using Robust.Shared.Log; using Robust.Shared.Utility; namespace Robust.Shared.Network { - /// + /// /// Callback for registered NetMessages. /// /// The message received. @@ -37,6 +38,9 @@ namespace Robust.Shared.Network /// public partial class NetManager : IClientNetManager, IServerNetManager, IDisposable { + + [Dependency] private readonly IRobustSerializer _serializer = default!; + private static readonly Counter SentPacketsMetrics = Metrics.CreateCounter( "robust_net_sent_packets", "Number of packets sent since server startup."); @@ -226,7 +230,15 @@ namespace Robust.Shared.Network _config.RegisterCVar("net.fakelagrand", 0.0f, CVar.CHEAT, _fakeLagRandomChanged); #endif - _strings.Initialize(this, () => { OnConnected(ServerChannel!); }); + _strings.Initialize(this, () => + { + Logger.InfoS("net","Message string table loaded."); + }); + _serializer.ClientHandshakeComplete += () => + { + Logger.InfoS("net","Client completed serializer handshake."); + OnConnected(ServerChannel!); + }; _initialized = true; } @@ -621,7 +633,7 @@ namespace Robust.Shared.Network HandleInitialHandshakeComplete(connection); } - private void HandleInitialHandshakeComplete(NetConnection sender) + private async void HandleInitialHandshakeComplete(NetConnection sender) { var session = _assignedSessions[sender]; @@ -630,6 +642,8 @@ namespace Robust.Shared.Network _strings.SendFullTable(channel); + await _serializer.Handshake(channel); + Logger.InfoS("net", $"{channel.RemoteEndPoint}: Connected"); OnConnected(channel); @@ -666,13 +680,22 @@ namespace Robust.Shared.Network { var peer = msg.SenderConnection.Peer; if (peer.Status == NetPeerStatus.ShutdownRequested) + { + Logger.WarningS("net", $"{msg.SenderConnection.RemoteEndPoint}: Received data message, but shutdown is requested."); return true; + } if (peer.Status == NetPeerStatus.NotRunning) + { + Logger.WarningS("net", $"{msg.SenderConnection.RemoteEndPoint}: Received data message, peer is not running."); return true; + } if (!IsConnected) + { + Logger.WarningS("net", $"{msg.SenderConnection.RemoteEndPoint}: Received data message, but not connected."); return true; + } if (_awaitingData.TryGetValue(msg.SenderConnection, out var info)) { diff --git a/Robust.Shared/Network/StringTable.cs b/Robust.Shared/Network/StringTable.cs index e688fbd4c..5d63010ea 100644 --- a/Robust.Shared/Network/StringTable.cs +++ b/Robust.Shared/Network/StringTable.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using Lidgren.Network; using Robust.Shared.Interfaces.Network; +using Robust.Shared.Log; using Robust.Shared.Utility; namespace Robust.Shared.Network @@ -55,6 +56,7 @@ namespace Robust.Shared.Network { if (_network.IsServer) // Server does not receive entries from clients. return; + Logger.InfoS("net",$"Received message name string table."); foreach (var entry in message.Entries) { @@ -245,6 +247,7 @@ namespace Robust.Shared.Network } + Logger.InfoS("net",$"Sending message name string table to {channel.RemoteEndPoint.Address}."); _network.ServerSendMessage(message, channel); } } diff --git a/Robust.Shared/Prototypes/PrototypeManager.cs b/Robust.Shared/Prototypes/PrototypeManager.cs index be1bd5df1..d901924b5 100644 --- a/Robust.Shared/Prototypes/PrototypeManager.cs +++ b/Robust.Shared/Prototypes/PrototypeManager.cs @@ -10,6 +10,7 @@ using Robust.Shared.Interfaces.Resources; using Robust.Shared.IoC; using Robust.Shared.IoC.Exceptions; using Robust.Shared.Log; +using Robust.Shared.Serialization; using Robust.Shared.Utility; using YamlDotNet.Core; using YamlDotNet.RepresentationModel; @@ -211,7 +212,11 @@ namespace Robust.Shared.Prototypes var yamlStream = new YamlStream(); yamlStream.Load(reader); - return ((YamlStream? yamlStream, ResourcePath?))(yamlStream, filePath); + var result = ((YamlStream? yamlStream, ResourcePath?))(yamlStream, filePath); + + RobustSerializer.MappedStringSerializer.AddStrings(yamlStream, filePath.ToString()); + + return result; } catch (YamlException e) { @@ -255,6 +260,9 @@ namespace Robust.Shared.Prototypes throw new PrototypeLoadException(string.Format("Failed to load prototypes from document#{0}", i), e); } } + + + RobustSerializer.MappedStringSerializer.AddStrings(yaml, "anonymous prototypes YAML stream"); } #endregion IPrototypeManager members diff --git a/Robust.Shared/Robust.Shared.csproj b/Robust.Shared/Robust.Shared.csproj index a5e778640..bf1507b72 100644 --- a/Robust.Shared/Robust.Shared.csproj +++ b/Robust.Shared/Robust.Shared.csproj @@ -35,6 +35,21 @@ + + RobustSerializer.MappedStringSerializer.cs + + + RobustSerializer.MappedStringSerializer.cs + + + RobustSerializer.MappedStringSerializer.cs + + + RobustSerializer.cs + + + RobustSerializer.cs + diff --git a/Robust.Shared/Serialization/RobustSerializer.Handshake.cs b/Robust.Shared/Serialization/RobustSerializer.Handshake.cs new file mode 100644 index 000000000..fdc678f8e --- /dev/null +++ b/Robust.Shared/Serialization/RobustSerializer.Handshake.cs @@ -0,0 +1,35 @@ +using System; +using System.Threading.Tasks; +using Lidgren.Network; +using Robust.Shared.Interfaces.Network; + +namespace Robust.Shared.Serialization +{ + + public partial class RobustSerializer + { + /// + /// Initiates any sequence of handshake extensions that + /// need to occur before the serializer is initialized + /// for a given client. + /// + /// + /// + public Task Handshake(INetChannel channel) + => MappedStringSerializer.Handshake(channel); + + /// + /// An event that occurs once all handshake extensions have + /// completed for the client. + /// + /// Note: This should not occur on the server. + /// + public event Action ClientHandshakeComplete + { + add => MappedStringSerializer.ClientHandshakeComplete += value; + remove => MappedStringSerializer.ClientHandshakeComplete -= value; + } + + } + +} diff --git a/Robust.Shared/Serialization/RobustSerializer.MappedStringSerializer.MsgClientHandshake.cs b/Robust.Shared/Serialization/RobustSerializer.MappedStringSerializer.MsgClientHandshake.cs new file mode 100644 index 000000000..374f0052d --- /dev/null +++ b/Robust.Shared/Serialization/RobustSerializer.MappedStringSerializer.MsgClientHandshake.cs @@ -0,0 +1,54 @@ +using JetBrains.Annotations; +using Lidgren.Network; +using Robust.Shared.Interfaces.Network; +using Robust.Shared.Network; + +namespace Robust.Shared.Serialization +{ + + public partial class RobustSerializer + { + + public partial class MappedStringSerializer + { + /// + /// The client part of the string-exchange handshake, sent after the + /// client receives the mapping hash and after the client receives a + /// strings package. Tells the server if the client needs an updated + /// copy of the mapping. + /// + /// + /// Also sent by the client after a new copy of the string mapping + /// has been received. If successfully loaded, the value of + /// is false, otherwise it will be + /// true. + /// + /// + [UsedImplicitly] + private class MsgClientHandshake : NetMessage + { + + public MsgClientHandshake(INetChannel ch) + : base($"{nameof(RobustSerializer)}.{nameof(MappedStringSerializer)}.{nameof(MsgClientHandshake)}", MsgGroups.Core) + { + } + + /// + /// true if the client needs a new copy of the mapping, + /// false otherwise. + /// + public bool NeedsStrings { get; set; } + + public override void ReadFromBuffer(NetIncomingMessage buffer) + => NeedsStrings = buffer.ReadBoolean(); + + public override void WriteToBuffer(NetOutgoingMessage buffer) + => buffer.Write(NeedsStrings); + + } + + } + + } + +} diff --git a/Robust.Shared/Serialization/RobustSerializer.MappedStringSerializer.MsgServerHandshake.cs b/Robust.Shared/Serialization/RobustSerializer.MappedStringSerializer.MsgServerHandshake.cs new file mode 100644 index 000000000..132adebd5 --- /dev/null +++ b/Robust.Shared/Serialization/RobustSerializer.MappedStringSerializer.MsgServerHandshake.cs @@ -0,0 +1,65 @@ +using System; +using JetBrains.Annotations; +using Lidgren.Network; +using Robust.Shared.Interfaces.Network; +using Robust.Shared.Network; + +namespace Robust.Shared.Serialization +{ + + public partial class RobustSerializer + { + + public partial class MappedStringSerializer + { + + /// + /// The server part of the string-exchange handshake. Sent as the + /// first message in the handshake. Tells the client the hash of + /// the current string mapping, so the client can check if it has + /// a local copy. + /// + /// + [UsedImplicitly] + private class MsgServerHandshake : NetMessage + { + + public MsgServerHandshake(INetChannel ch) + : base($"{nameof(RobustSerializer)}.{nameof(MappedStringSerializer)}.{nameof(MsgServerHandshake)}", MsgGroups.Core) + { + } + + /// + /// The hash of the current string mapping held by the server. + /// + public byte[]? Hash { get; set; } + + public override void ReadFromBuffer(NetIncomingMessage buffer) + { + var len = buffer.ReadVariableInt32(); + if (len > 64) + { + throw new InvalidOperationException("Hash too long."); + } + + Hash = buffer.ReadBytes(len); + } + + public override void WriteToBuffer(NetOutgoingMessage buffer) + { + if (Hash == null) + { + throw new InvalidOperationException("Package has not been specified."); + } + + buffer.WriteVariableInt32(Hash.Length); + buffer.Write(Hash); + } + + } + + } + + } + +} diff --git a/Robust.Shared/Serialization/RobustSerializer.MappedStringSerializer.MsgStrings.cs b/Robust.Shared/Serialization/RobustSerializer.MappedStringSerializer.MsgStrings.cs new file mode 100644 index 000000000..e55ef691e --- /dev/null +++ b/Robust.Shared/Serialization/RobustSerializer.MappedStringSerializer.MsgStrings.cs @@ -0,0 +1,72 @@ +using System; +using System.IO; +using JetBrains.Annotations; +using Lidgren.Network; +using Robust.Shared.Interfaces.Network; +using Robust.Shared.Network; + +namespace Robust.Shared.Serialization +{ + + public partial class RobustSerializer + { + + public partial class MappedStringSerializer + { + + /// + /// The meat of the string-exchange handshake sandwich. Sent by the + /// server after the client requests an updated copy of the mapping. + /// Contains the updated string mapping. + /// + /// + [UsedImplicitly] + private class MsgStrings : NetMessage + { + + public MsgStrings(INetChannel ch) + : base($"{nameof(RobustSerializer)}.{nameof(MappedStringSerializer)}.{nameof(MsgStrings)}", MsgGroups.Core) + { + } + + /// + /// The raw bytes of the string mapping held by the server. + /// + public byte[]? Package { get; set; } + + public override void ReadFromBuffer(NetIncomingMessage buffer) + { + var l = buffer.ReadVariableInt32(); + var success = buffer.ReadBytes(l, out var buf); + if (!success) + { + throw new InvalidDataException("Not all of the bytes were available in the message."); + } + + Package = buf; + } + + public override void WriteToBuffer(NetOutgoingMessage buffer) + { + if (Package == null) + { + throw new InvalidOperationException("Package has not been specified."); + } + + buffer.WriteVariableInt32(Package.Length); + var start = buffer.LengthBytes; + buffer.Write(Package); + var added = buffer.LengthBytes - start; + if (added != Package.Length) + { + throw new InvalidOperationException("Not all of the bytes were written to the message."); + } + } + + } + + } + + } + +} diff --git a/Robust.Shared/Serialization/RobustSerializer.MappedStringSerializer.cs b/Robust.Shared/Serialization/RobustSerializer.MappedStringSerializer.cs new file mode 100644 index 000000000..9c7d1d1e9 --- /dev/null +++ b/Robust.Shared/Serialization/RobustSerializer.MappedStringSerializer.cs @@ -0,0 +1,1295 @@ +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Diagnostics; +using System.IO; +using System.IO.Compression; +using System.Linq; +using System.Reflection; +using System.Reflection.Metadata; +using System.Reflection.Metadata.Ecma335; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Security.Cryptography; +using System.Text; +using System.Text.RegularExpressions; +using System.Threading.Tasks; +using JetBrains.Annotations; +using NetSerializer; +using Newtonsoft.Json.Linq; +using Robust.Shared.ContentPack; +using Robust.Shared.Interfaces.Log; +using Robust.Shared.Interfaces.Network; +using Robust.Shared.IoC; +using Robust.Shared.Log; +using Robust.Shared.Utility; +using YamlDotNet.RepresentationModel; + +namespace Robust.Shared.Serialization +{ + + public partial class RobustSerializer + { + + /// + /// Serializer which manages a mapping of pre-loaded strings to constant + /// values, for message compression. The mapping is shared between the + /// server and client. + /// + /// + /// Strings are long and expensive to send over the wire, and lots of + /// strings involved in messages are sent repeatedly between the server + /// and client - such as filenames, icon states, constant strings, etc. + /// + /// To compress these strings, we use a constant string mapping, decided + /// by the server when it starts up, that associates strings with a + /// fixed value. The mapping is shared with clients when they connect. + /// + /// When sending these strings over the wire, the serializer can then + /// send the constant value instead - and at the other end, the + /// serializer can use the same mapping to recover the original string. + /// + public partial class MappedStringSerializer : IStaticTypeSerializer + { + + private static INetManager? _net; + + private static readonly ISawmill LogSzr = Logger.GetSawmill("szr"); + + private static readonly HashSet IncompleteHandshakes = new HashSet(); + + /// + /// Starts the handshake from the server end of the given channel, + /// sending a . + /// + /// The network channel to perform the handshake over. + /// + /// Locks the string mapping if this is the first time the server is + /// performing the handshake. + /// + /// + /// + public static async Task Handshake(INetChannel channel) + { + var net = channel.NetPeer; + + if (net.IsClient) + { + return; + } + + if (!LockMappedStrings) + { + LockMappedStrings = true; + LogSzr.Info($"Locked in at {_MappedStrings.Count} mapped strings."); + } + + IncompleteHandshakes.Add(channel); + + var message = net.CreateNetMessage(); + message.Hash = MappedStringsHash; + net.ServerSendMessage(message, channel); + + while (IncompleteHandshakes.Contains(channel)) + { + await Task.Delay(1); + } + + LogSzr.Info($"Completed handshake with {channel.RemoteEndPoint.Address}."); + } + + /// + /// Performs the setup so that the serializer can perform the string- + /// exchange protocol. + /// + /// + /// The string-exchange protocol is started by the server when the + /// client first connects. The server sends the client a hash of the + /// string mapping; the client checks that hash against any local + /// caches; and if necessary, the client requests a new copy of the + /// mapping from the server. + /// + /// Uncached flow: + /// Client | Server + /// | <-------------- Hash | + /// | Need Strings ------> | + /// | <----------- Strings | + /// | Dont Need Strings -> | + /// + /// + /// Cached flow: + /// Client | Server + /// | <-------------- Hash | + /// | Dont Need Strings -> | + /// + /// + /// Verification failure flow: + /// Client | Server + /// | <-------------- Hash | + /// | Need Strings ------> | + /// | <----------- Strings | + /// + Hash Failed | + /// | Need Strings ------> | + /// | <----------- Strings | + /// | Dont Need Strings -> | + /// + /// + /// NOTE: Verification failure flow is currently not implemented. + /// + /// + /// The to perform the protocol steps over. + /// + /// + /// + /// + /// + /// + /// + /// + public static void NetworkInitialize(INetManager net) + { + _net = net; + + net.RegisterNetMessage( + $"{nameof(RobustSerializer)}.{nameof(MappedStringSerializer)}.{nameof(MsgServerHandshake)}", + msg => HandleServerHandshake(net, msg)); + + net.RegisterNetMessage( + $"{nameof(RobustSerializer)}.{nameof(MappedStringSerializer)}.{nameof(MsgClientHandshake)}", + msg => HandleClientHandshake(net, msg)); + + net.RegisterNetMessage( + $"{nameof(RobustSerializer)}.{nameof(MappedStringSerializer)}.{nameof(MsgStrings)}", + msg => HandleStringsMessage(net, msg)); + } + + /// + /// Handles the reception, verification of a strings package + /// and subsequent mapping of strings and initiator of + /// receipt response. + /// + /// Uncached flow: + /// Client | Server + /// | <-------------- Hash | + /// | Need Strings ------> | + /// | <----------- Strings | + /// | Dont Need Strings -> | <- you are here on client + /// + /// Verification failure flow: + /// Client | Server + /// | <-------------- Hash | + /// | Need Strings ------> | + /// | <----------- Strings | + /// + Hash Failed | <- you are here on client + /// | Need Strings ------> | + /// | <----------- Strings | + /// | Dont Need Strings -> | <- you are here on client + /// + /// + /// NOTE: Verification failure flow is currently not implemented. + /// + /// + /// Unable to verify strings package by hash. + /// + private static void HandleStringsMessage(INetManager net, MsgStrings msg) + { + if (net.IsServer) + { + LogSzr.Error("Received strings from client."); + return; + } + + LockMappedStrings = false; + ClearStrings(); + DebugTools.Assert(msg.Package != null, "msg.Package != null"); + LoadStrings(new MemoryStream(msg.Package!, false)); + var checkHash = CalculateHash(msg.Package!); + if (!checkHash.SequenceEqual(ServerHash)) + { + // TODO: retry sending MsgClientHandshake with NeedsStrings = false + throw new InvalidOperationException("Unable to verify strings package by hash." + $"\n{ConvertToBase64Url(checkHash)} vs. {ConvertToBase64Url(ServerHash)}"); + } + + _stringMapHash = ServerHash; + LockMappedStrings = true; + + LogSzr.Info($"Locked in at {_MappedStrings.Count} mapped strings."); + + WriteStringCache(); + + // ok we're good now + var channel = msg.MsgChannel; + OnClientCompleteHandshake(net, channel); + } + + /// + /// Interpret a client's handshake, either sending a package + /// of strings or completing the handshake. + /// + /// Uncached flow: + /// Client | Server + /// | <-------------- Hash | + /// | Need Strings ------> | <- you are here on server + /// | <----------- Strings | + /// | Dont Need Strings -> | <- you are here on server + /// + /// + /// Cached flow: + /// Client | Server + /// | <-------------- Hash | + /// | Dont Need Strings -> | <- you are here on server + /// + /// + /// Verification failure flow: + /// Client | Server + /// | <-------------- Hash | + /// | Need Strings ------> | <- you are here on server + /// | <----------- Strings | + /// + Hash Failed | + /// | Need Strings ------> | <- you are here on server + /// | <----------- Strings | + /// | Dont Need Strings -> | + /// + /// + /// NOTE: Verification failure flow is currently not implemented. + /// + /// + private static void HandleClientHandshake(INetManager net, MsgClientHandshake msg) + { + if (net.IsClient) + { + LogSzr.Error("Received client handshake on client."); + return; + } + + LogSzr.Info($"Received handshake from {msg.MsgChannel.RemoteEndPoint.Address}."); + + if (!msg.NeedsStrings) + { + LogSzr.Info($"Completing handshake with {msg.MsgChannel.RemoteEndPoint.Address}."); + IncompleteHandshakes.Remove(msg.MsgChannel); + return; + } + + // TODO: count and limit number of requests to send strings during handshake + + var strings = msg.MsgChannel.NetPeer.CreateNetMessage(); + using (var ms = new MemoryStream()) + { + WriteStringPackage(ms); + ms.Position = 0; + strings.Package = ms.ToArray(); + LogSzr.Info($"Sending {ms.Length} bytes sized mapped strings package to {msg.MsgChannel.RemoteEndPoint.Address}."); + } + + msg.MsgChannel.SendMessage(strings); + } + + /// + /// Interpret a server's handshake, either requesting a package + /// of strings or completing the handshake. + /// + /// Uncached flow: + /// Client | Server + /// | <-------------- Hash | <- you are here on client + /// | Need Strings ------> | + /// | <----------- Strings | + /// | Dont Need Strings -> | + /// + /// + /// Cached flow: + /// Client | Server + /// | <-------------- Hash | <- you are here on client + /// | Dont Need Strings -> | + /// + /// + /// Verification failure flow: + /// Client | Server + /// | <-------------- Hash | <- you are here on client + /// | Need Strings ------> | + /// | <----------- Strings | + /// + Hash Failed | + /// | Need Strings ------> | + /// | <----------- Strings | + /// | Dont Need Strings -> | + /// + /// + /// NOTE: Verification failure flow is currently not implemented. + /// + /// Mapped strings are locked. + /// + private static void HandleServerHandshake(INetManager net, MsgServerHandshake msg) + { + if (net.IsServer) + { + LogSzr.Error("Received server handshake on server."); + return; + } + + ServerHash = msg.Hash; + LockMappedStrings = false; + + if (LockMappedStrings) + { + throw new InvalidOperationException("Mapped strings are locked."); + } + + ClearStrings(); + + var hashStr = ConvertToBase64Url(Convert.ToBase64String(msg.Hash!)); + + LogSzr.Info($"Received server handshake with hash {hashStr}."); + + var fileName = CacheForHash(hashStr); + if (!File.Exists(fileName)) + { + LogSzr.Info($"No string cache for {hashStr}."); + var handshake = net.CreateNetMessage(); + LogSzr.Info("Asking server to send mapped strings."); + handshake.NeedsStrings = true; + msg.MsgChannel.SendMessage(handshake); + } + else + { + LogSzr.Info($"We had a cached string map that matches {hashStr}."); + using var file = File.OpenRead(fileName); + var added = LoadStrings(file); + + _stringMapHash = msg.Hash!; + LogSzr.Info($"Read {added} strings from cache {hashStr}."); + LockMappedStrings = true; + LogSzr.Info($"Locked in at {_MappedStrings.Count} mapped strings."); + // ok we're good now + var channel = msg.MsgChannel; + OnClientCompleteHandshake(net, channel); + } + } + + /// + /// Inform the server that the client has a complete copy of the + /// mapping, and alert other code that the handshake is over. + /// + /// + /// + private static void OnClientCompleteHandshake(INetManager net, INetChannel channel) + { + LogSzr.Info("Letting server know we're good to go."); + var handshake = net.CreateNetMessage(); + handshake.NeedsStrings = false; + channel.SendMessage(handshake); + + if (ClientHandshakeComplete == null) + { + LogSzr.Warning("There's no handler attached to ClientHandshakeComplete."); + } + + ClientHandshakeComplete?.Invoke(); + } + + /// + /// Gets the cache file associated with the given hash. + /// + /// The hash to look up the cache for. + /// + /// The filename where the cache for the given hash would be. The + /// file itself may or may not exist. If it does not exist, no cache + /// was made for the given hash. + /// + private static string CacheForHash(string hashStr) + => PathHelpers.ExecutableRelativeFile($"strings-{hashStr}"); + + /// + /// Saves the string cache to a file based on it's hash. + /// + private static void WriteStringCache() + { + var hashStr = Convert.ToBase64String(MappedStringsHash); + hashStr = ConvertToBase64Url(hashStr); + + var fileName = CacheForHash(hashStr); + using var file = File.OpenWrite(fileName); + WriteStringPackage(file); + + LogSzr.Info($"Wrote string cache {hashStr}."); + } + + private static byte[]? _mappedStringsPackage; + + private static byte[] MappedStringsPackage => LockMappedStrings + ? _mappedStringsPackage ??= WriteStringPackage() + : throw new InvalidOperationException("Mapped strings must be locked."); + + /// + /// Writes strings to a package and converts to an array of bytes. + /// + /// + /// This is invoked by accessing for the first time. + /// + private static byte[] WriteStringPackage() + { + using var ms = new MemoryStream(); + WriteStringPackage(ms); + return ms.ToArray(); + } + + /// + /// Strings longer than this will throw an exception and a better strategy will need to be employed to deal with large strings. + /// + public static int StringPackageMaximumBufferSize = 65536; + + /// + /// Writes a strings package to a stream. + /// + /// A writable stream. + /// Overly long string in strings package. + public static void WriteStringPackage(Stream stream) + { + var buf = new byte[StringPackageMaximumBufferSize]; + var sw = Stopwatch.StartNew(); + var enc = Encoding.UTF8.GetEncoder(); + + using (var zs = new DeflateStream(stream, CompressionLevel.Optimal, true)) + { + var bytesWritten = WriteCompressedUnsignedInt(zs, (uint) MappedStrings.Count); + foreach (var str in MappedStrings) + { + if (str.Length >= StringPackageMaximumBufferSize) + { + throw new NotImplementedException("Overly long string in strings package."); + } + + var l = enc.GetBytes(str, buf, true); + + if (l >= StringPackageMaximumBufferSize) + { + throw new NotImplementedException("Overly long string in strings package."); + } + + bytesWritten += WriteCompressedUnsignedInt(zs, (uint) l); + + zs.Write(buf, 0, l); + + bytesWritten += l; + + enc.Reset(); + } + + zs.Write(BitConverter.GetBytes(bytesWritten)); + zs.Flush(); + } + + LogSzr.Info($"Wrote {MappedStrings.Count} strings to package in {sw.ElapsedMilliseconds}ms."); + } + + /// + /// Loads a strings package from a stream. + /// + /// + /// Uses to extract strings and adds them to the mapping. + /// + /// A readable stream. + /// The number of strings loaded. + /// Mapped strings are locked, will not load. + /// Did not read all bytes in package! + private static int LoadStrings(Stream stream) + { + if (LockMappedStrings) + { + throw new InvalidOperationException("Mapped strings are locked, will not load."); + } + + var started = MappedStrings.Count; + foreach (var str in ReadStringPackage(stream)) + { + _StringMapping[str] = _MappedStrings.Count; + _MappedStrings.Add(str); + } + + if (stream.CanSeek && stream.CanRead) + { + if (stream.Position != stream.Length) + { + throw new InvalidDataException("Did not read all bytes in package!"); + } + } + + var added = MappedStrings.Count - started; + return added; + } + + /// + /// Reads the contents of a strings package. + /// + /// + /// Does not add strings to the current mapping. + /// + /// A readable stream. + /// Strings from within the package. + /// Could not read the full length of string #N. + private static IEnumerable ReadStringPackage(Stream stream) + { + var buf = ArrayPool.Shared.Rent(65536); + var sw = Stopwatch.StartNew(); + using var zs = new DeflateStream(stream, CompressionMode.Decompress); + + var c = ReadCompressedUnsignedInt(zs, out var x); + var bytesRead = x; + for (var i = 0; i < c; ++i) + { + var l = (int) ReadCompressedUnsignedInt(zs, out x); + bytesRead += x; + var y = zs.Read(buf, 0, l); + if (y != l) + { + throw new InvalidDataException($"Could not read the full length of string #{i}."); + } + + bytesRead += y; + var str = Encoding.UTF8.GetString(buf, 0, l); + yield return str; + } + + zs.Read(buf, 0, 4); + var checkBytesRead = BitConverter.ToInt32(buf, 0); + if (checkBytesRead != bytesRead) + { + throw new InvalidDataException("Could not verify package was read correctly."); + } + + LogSzr.Info($"Read package of {c} strings in {sw.ElapsedMilliseconds}ms."); + } + + /// + /// Converts a byte array such as a hash to a Base64 representation that is URL safe. + /// + /// + /// A base64url string form of the byte array. + private static string ConvertToBase64Url(byte[]? data) + => data == null ? "" : ConvertToBase64Url(Convert.ToBase64String(data)); + + /// + /// Converts a a Base64 string to one that is URL safe. + /// + /// A base64url formed string. + private static string ConvertToBase64Url(string b64Str) + { + if (b64Str is null) + { + throw new ArgumentNullException(nameof(b64Str)); + } + + var cut = b64Str[^1] == '=' ? b64Str[^2] == '=' ? 2 : 1 : 0; + b64Str = new StringBuilder(b64Str).Replace('+', '-').Replace('/', '_').ToString(0, b64Str.Length - cut); + return b64Str; + } + + /// + /// Converts a URL-safe Base64 string into a byte array. + /// + /// A base64url formed string. + /// The represented byte array. + public static byte[] ConvertFromBase64Url(string s) + { + var l = s.Length % 3; + var sb = new StringBuilder(s); + sb.Replace('-', '+').Replace('_', '/'); + for (var i = 0; i < l; ++i) + { + sb.Append('='); + } + + s = sb.ToString(); + return Convert.FromBase64String(s); + } + + public static byte[]? ServerHash; + + private static readonly IList _MappedStrings = new List(); + + private static readonly IDictionary _StringMapping = new Dictionary(); + + public static IReadOnlyList MappedStrings => new ReadOnlyCollection(_MappedStrings); + + /// + /// Whether the string mapping is decided, and cannot be changed. + /// + /// + /// + /// While false, strings can be added to the mapping, but + /// it cannot be saved to a cache. + /// + /// + /// While true, the mapping cannot be modified, but can be + /// shared between the server and client and saved to a cache. + /// + /// + public static bool LockMappedStrings { get; set; } + + private static readonly Regex RxSymbolSplitter + = new Regex( + @"(?<=[^\s\W])(?=[A-Z]) # Match for split at start of new capital letter + |(?<=[^0-9\s\W])(?=[0-9]) # Match for split before spans of numbers + |(?<=[A-Za-z0-9])(?=_) # Match for a split before an underscore + |(?=[.\\\/,#$?!@|&*()^`""'`~[\]{}:;\-]) # Match for a split after symbols + |(?<=[.\\\/,#$?!@|&*()^`""'`~[\]{}:;\-]) # Match for a split before symbols too", + RegexOptions.CultureInvariant + | RegexOptions.Compiled + | RegexOptions.IgnorePatternWhitespace + ); + + /// + /// Add a string to the constant mapping. + /// + /// + /// If the string has multiple detectable subcomponents, such as a + /// filepath, it may result in more than one string being added to + /// the mapping. As string parts are commonly sent as subsets or + /// scoped names, this increases the likelyhood of a successful + /// string mapping. + /// + /// + /// true if the string was added to the mapping for the first + /// time, false otherwise. + /// + /// + /// Thrown if the mapping is locked, and strings cannot be added, or + /// if the string is not normalized (). + /// + public static bool AddString(string str) + { + if (LockMappedStrings) + { + if (_net.IsClient) + { + //LogSzr.Info("On client and mapped strings are locked, will not add."); + return false; + } + + throw new InvalidOperationException("Mapped strings are locked, will not add."); + } + + if (string.IsNullOrEmpty(str)) + { + return false; + } + + if (!str.IsNormalized()) + { + throw new InvalidOperationException("Only normalized strings may be added."); + } + + if (_StringMapping.ContainsKey(str)) + { + return false; + } + + if (str.Length >= MaxMappedStringSize) return false; + + if (str.Length <= MinMappedStringSize) return false; + + str = str.Trim(); + + if (str.Length <= MinMappedStringSize) return false; + + str = str.Replace(Environment.NewLine, "\n"); + + if (str.Length <= MinMappedStringSize) return false; + + var symTrimmedStr = str.Trim(TrimmableSymbolChars); + if (symTrimmedStr != str) + { + AddString(symTrimmedStr); + } + + if (str.Contains('/')) + { + var parts = str.Split('/', StringSplitOptions.RemoveEmptyEntries); + for (var i = 0; i < parts.Length; ++i) + { + for (var l = 1; l <= parts.Length - i; ++l) + { + var subStr = string.Join('/', parts.Skip(i).Take(l)); + if (_StringMapping.TryAdd(subStr, _MappedStrings.Count)) + { + _MappedStrings.Add(subStr); + } + + if (!subStr.Contains('.')) + { + continue; + } + + var subParts = subStr.Split('.', StringSplitOptions.RemoveEmptyEntries); + for (var si = 0; si < subParts.Length; ++si) + { + for (var sl = 1; sl <= subParts.Length - si; ++sl) + { + var subSubStr = string.Join('.', subParts.Skip(si).Take(sl)); + // ReSharper disable once InvertIf + if (_StringMapping.TryAdd(subSubStr, _MappedStrings.Count)) + { + _MappedStrings.Add(subSubStr); + } + } + } + } + } + } + else if (str.Contains("_")) + { + foreach (var substr in str.Split("_")) + { + AddString(substr); + } + } + else if (str.Contains(" ")) + { + foreach (var substr in str.Split(" ")) + { + if (substr == str) continue; + + AddString(substr); + } + } + else + { + var parts = RxSymbolSplitter.Split(str); + foreach (var substr in parts) + { + if (substr == str) continue; + + AddString(substr); + } + + for (var si = 0; si < parts.Length; ++si) + { + for (var sl = 1; sl <= parts.Length - si; ++sl) + { + var subSubStr = string.Concat(parts.Skip(si).Take(sl)); + if (_StringMapping.TryAdd(subSubStr, _MappedStrings.Count)) + { + _MappedStrings.Add(subSubStr); + } + } + } + } + + if (_StringMapping.TryAdd(str, _MappedStrings.Count)) + { + _MappedStrings.Add(str); + } + + _stringMapHash = null; + _mappedStringsPackage = null; + return true; + } + + /// + /// Add the constant strings from an to the + /// mapping. + /// + /// The assembly from which to collect constant strings. + /// + /// Thrown if the mapping is locked. + /// + [MethodImpl(MethodImplOptions.Synchronized)] + public static unsafe void AddStrings(Assembly asm) + { + if (LockMappedStrings) + { + if (_net.IsClient) + { + //LogSzr.Info("On client and mapped strings are locked, will not add."); + return; + } + + throw new InvalidOperationException("Mapped strings are locked, will not add ."); + } + + var started = MappedStrings.Count; + var sw = Stopwatch.StartNew(); + if (asm.TryGetRawMetadata(out var blob, out var len)) + { + var reader = new MetadataReader(blob, len); + var usrStrHandle = default(UserStringHandle); + do + { + var userStr = reader.GetUserString(usrStrHandle); + if (userStr != "") + { + AddString(string.Intern(userStr.Normalize())); + } + + usrStrHandle = reader.GetNextHandle(usrStrHandle); + } while (usrStrHandle != default); + + var strHandle = default(StringHandle); + do + { + var str = reader.GetString(strHandle); + if (str != "") + { + AddString(string.Intern(str.Normalize())); + } + + strHandle = reader.GetNextHandle(strHandle); + } while (strHandle != default); + } + + var added = MappedStrings.Count - started; + LogSzr.Info($"Mapping {added} strings from {asm.GetName().Name} took {sw.ElapsedMilliseconds}ms."); + } + + /// + /// Add strings from the given to the mapping. + /// + /// + /// Strings are taken from YAML anchors, tags, and leaf nodes. + /// + /// The YAML to collect strings from. + /// The stream name. Only used for logging. + /// + /// Thrown if the mapping is locked. + /// + [MethodImpl(MethodImplOptions.Synchronized)] + public static void AddStrings(YamlStream yaml, string name) + { + if (LockMappedStrings) + { + if (_net.IsClient) + { + //LogSzr.Info("On client and mapped strings are locked, will not add."); + return; + } + + throw new InvalidOperationException("Mapped strings are locked, will not add."); + } + + var started = MappedStrings.Count; + var sw = Stopwatch.StartNew(); + foreach (var doc in yaml) + { + foreach (var node in doc.AllNodes) + { + var a = node.Anchor; + if (!string.IsNullOrEmpty(a)) + { + AddString(a); + } + + var t = node.Tag; + if (!string.IsNullOrEmpty(t)) + { + AddString(t); + } + + switch (node) + { + case YamlScalarNode scalar: + { + var v = scalar.Value; + if (string.IsNullOrEmpty(v)) + { + continue; + } + + AddString(v); + break; + } + } + } + } + + var added = MappedStrings.Count - started; + LogSzr.Info($"Mapping {added} strings from {name} took {sw.ElapsedMilliseconds}ms."); + } + + /// + /// Add strings from the given to the mapping. + /// + /// + /// Strings are taken from JSON property names and string nodes. + /// + /// The JSON to collect strings from. + /// The stream name. Only used for logging. + /// + /// Thrown if the mapping is locked. + /// + public static void AddStrings(JObject obj, string name) + { + if (LockMappedStrings) + { + if (_net.IsClient) + { + //LogSzr.Info("On client and mapped strings are locked, will not add."); + return; + } + + throw new InvalidOperationException("Mapped strings are locked, will not add."); + } + + var started = MappedStrings.Count; + var sw = Stopwatch.StartNew(); + foreach (var node in obj.DescendantsAndSelf()) + { + switch (node) + { + case JValue value: + { + if (value.Type != JTokenType.String) + { + continue; + } + + var v = value.Value?.ToString(); + if (string.IsNullOrEmpty(v)) + { + continue; + } + + AddString(v); + break; + } + case JProperty prop: + { + var propName = prop.Name; + if (string.IsNullOrEmpty(propName)) + { + continue; + } + + AddString(propName); + break; + } + } + } + + var added = MappedStrings.Count - started; + LogSzr.Info($"Mapping {added} strings from {name} took {sw.ElapsedMilliseconds}ms."); + } + + /// + /// Remove all strings from the mapping, completely resetting it. + /// + /// + /// Thrown if the mapping is locked. + /// + public static void ClearStrings() + { + if (LockMappedStrings) + { + throw new InvalidOperationException("Mapped strings are locked, will not clear."); + } + + _MappedStrings.Clear(); + _StringMapping.Clear(); + _stringMapHash = null; + } + + /// + /// Add strings from the given enumeration to the mapping. + /// + /// The strings to add. + /// The source provider of the strings to be logged. + /// + /// Thrown if the mapping is locked. + /// + [MethodImpl(MethodImplOptions.Synchronized)] + public static void AddStrings(IEnumerable strings, string providerName) + { + if (LockMappedStrings) + { + if (_net.IsClient) + { + //LogSzr.Info("On client and mapped strings are locked, will not add."); + return; + } + + throw new InvalidOperationException("Mapped strings are locked, will not add."); + } + + var started = MappedStrings.Count; + foreach (var str in strings) + { + AddString(str); + } + + var added = MappedStrings.Count - started; + LogSzr.Info($"Mapping {added} strings from {providerName}."); + } + + private static byte[]? _stringMapHash; + + /// + /// The hash of the string mapping. + /// + /// + /// Thrown if the mapping is not locked. + /// + public static byte[] MappedStringsHash => _stringMapHash ??= CalculateMappedStringsHash(); + + private static byte[] CalculateMappedStringsHash() + { + if (!LockMappedStrings) + { + throw new InvalidOperationException("String table should be locked before attempting to retrieve hash."); + } + + var sw = Stopwatch.StartNew(); + + var hash = CalculateHash(MappedStringsPackage); + + LogSzr.Info($"Hashing {MappedStrings.Count} strings took {sw.ElapsedMilliseconds}ms."); + LogSzr.Info($"Size: {MappedStringsPackage.Length} bytes, Hash: {ConvertToBase64Url(hash)}"); + return hash; + } + + /// + /// Creates a SHA512 hash of the given array of bytes. + /// + /// An array of bytes to be hashed. + /// A 512-bit (64-byte) hash result as an array of bytes. + /// + private static byte[] CalculateHash(byte[] data) + { + if (data is null) + { + throw new ArgumentNullException(nameof(data)); + } + + using var hasher = SHA512.Create(); + var hash = hasher.ComputeHash(data); + return hash; + } + + /// + /// Implements . + /// Specifies that this implementation handles strings. + /// + public bool Handles(Type type) => type == typeof(string); + + /// + /// Implements . + /// + public IEnumerable GetSubtypes(Type type) => Type.EmptyTypes; + + /// + /// Implements . + /// + /// + public MethodInfo GetStaticWriter(Type type) => WriteMappedStringMethodInfo; + + /// + /// Implements . + /// + /// + public MethodInfo GetStaticReader(Type type) => ReadMappedStringMethodInfo; + + private delegate void WriteStringDelegate(Stream stream, string? value); + + private delegate void ReadStringDelegate(Stream stream, out string? value); + + private static readonly MethodInfo WriteMappedStringMethodInfo + = ((WriteStringDelegate) WriteMappedString).Method; + + private static readonly MethodInfo ReadMappedStringMethodInfo + = ((ReadStringDelegate) ReadMappedString).Method; + + private static readonly char[] TrimmableSymbolChars = + { + '.', '\\', '/', ',', '#', '$', '?', '!', '@', '|', '&', '*', '(', ')', '^', '`', '"', '\'', '`', '~', '[', ']', '{', '}', ':', ';', '-' + }; + + /// + /// The shortest a string can be in order to be inserted in the mapping. + /// + /// + /// Strings below a certain length aren't worth compressing. + /// + private const int MinMappedStringSize = 3; + + /// + /// The longest a string can be in order to be inserted in the mapping. + /// + private const int MaxMappedStringSize = 420; + + /// + /// The special value corresponding to a null string in the + /// encoding. + /// + private const int MappedNull = 0; + + /// + /// The special value corresponding to a string which was not mapped. + /// This is followed by the bytes of the unmapped string. + /// + private const int UnmappedString = 1; + + /// + /// The first non-special value, used for encoding mapped strings. + /// + /// + /// Since previous values are taken by and + /// , this value is used to encode + /// mapped strings at an offset - in the encoding, a value + /// >= FirstMappedIndexStart represents the string with + /// mapping of that value - FirstMappedIndexStart. + /// + private const int FirstMappedIndexStart = 2; + + /// + /// Write the encoding of the given string to the stream. + /// + /// The stream to write to. + /// The (possibly null) string to write. + public static void WriteMappedString(Stream stream, string? value) + { + if (!LockMappedStrings) + { + LogSzr.Warning("Performing unlocked string mapping."); + } + + if (value == null) + { + WriteCompressedUnsignedInt(stream, MappedNull); + return; + } + + if (_StringMapping.TryGetValue(value, out var mapping)) + { +#if DEBUG + if (mapping >= _MappedStrings.Count || mapping < 0) + { + throw new InvalidOperationException("A string mapping outside of the mapped string table was encountered."); + } +#endif + WriteCompressedUnsignedInt(stream, (uint) mapping + FirstMappedIndexStart); + //Logger.DebugS("szr", $"Encoded mapped string: {value}"); + return; + } + + // indicate not mapped + WriteCompressedUnsignedInt(stream, UnmappedString); + var buf = Encoding.UTF8.GetBytes(value); + //Logger.DebugS("szr", $"Encoded unmapped string: {value}"); + WriteCompressedUnsignedInt(stream, (uint) buf.Length); + stream.Write(buf); + } + + /// + /// Try to read a string from the given stream.. + /// + /// The stream to read from. + /// The (possibly null) string read. + /// + /// Thrown if the mapping is not locked. + /// + public static void ReadMappedString(Stream stream, out string? value) + { + if (!LockMappedStrings) + { + throw new InvalidOperationException("Not performing unlocked string mapping."); + } + + var mapIndex = ReadCompressedUnsignedInt(stream, out _); + if (mapIndex == MappedNull) + { + value = null; + return; + } + + if (mapIndex == UnmappedString) + { + // not mapped + var length = ReadCompressedUnsignedInt(stream, out _); + var buf = new byte[length]; + stream.Read(buf); + value = Encoding.UTF8.GetString(buf); + //Logger.DebugS("szr", $"Decoded unmapped string: {value}"); + return; + } + + value = _MappedStrings[(int) mapIndex - FirstMappedIndexStart]; + //Logger.DebugS("szr", $"Decoded mapped string: {value}"); + } + +#if ROBUST_SERIALIZER_DISABLE_COMPRESSED_UINTS + public static int WriteCompressedUnsignedInt(Stream stream, uint value) + { + WriteUnsignedInt(stream, value); + return 4; + } + + public static uint ReadCompressedUnsignedInt(Stream stream, out int byteCount) + { + byteCount = 4; + return ReadUnsignedInt(stream); + } +#else + public static int WriteCompressedUnsignedInt(Stream stream, uint value) + { + var length = 1; + while (value >= 0x80) + { + stream.WriteByte((byte) (0x80 | value)); + value >>= 7; + ++length; + } + + stream.WriteByte((byte) value); + return length; + } + + public static uint ReadCompressedUnsignedInt(Stream stream, out int byteCount) + { + byteCount = 0; + var value = 0u; + var shift = 0; + while (stream.CanRead) + { + var current = stream.ReadByte(); + ++byteCount; + if (current == -1) + { + throw new EndOfStreamException(); + } + + value |= (0x7Fu & (byte) current) << shift; + shift += 7; + if ((0x80 & current) == 0) + { + return value; + } + } + + throw new EndOfStreamException(); + } +#endif + + [UsedImplicitly] + public static unsafe void WriteUnsignedInt(Stream stream, uint value) + { + var bytes = MemoryMarshal.AsBytes(new ReadOnlySpan(&value, 1)); + stream.Write(bytes); + } + + [UsedImplicitly] + public static unsafe uint ReadUnsignedInt(Stream stream) + { + uint value; + var bytes = MemoryMarshal.AsBytes(new Span(&value, 1)); + stream.Read(bytes); + return value; + } + + /// + /// See . + /// + public static event Action? ClientHandshakeComplete; + + } + + } + +} diff --git a/Robust.Shared/Serialization/RobustSerializer.cs b/Robust.Shared/Serialization/RobustSerializer.cs index abc1e9c89..fffacd939 100644 --- a/Robust.Shared/Serialization/RobustSerializer.cs +++ b/Robust.Shared/Serialization/RobustSerializer.cs @@ -6,20 +6,27 @@ using System; using System.Collections.Generic; using System.IO; using System.Linq; +using System.Runtime.Loader; +using Robust.Shared.Interfaces.Network; namespace Robust.Shared.Serialization { - public class RobustSerializer : IRobustSerializer + public partial class RobustSerializer : IRobustSerializer { - [Dependency] private readonly IReflectionManager reflectionManager = default!; - private Serializer Serializer = default!; + [Dependency] private readonly IReflectionManager _reflectionManager = default!; + [Dependency] private readonly INetManager _netManager = default!; - private HashSet SerializableTypes = default!; -#region Statistics + private Serializer _serializer = default!; + + private HashSet _serializableTypes = default!; + + #region Statistics + public static long LargestObjectSerializedBytes { get; private set; } + public static Type? LargestObjectSerializedType { get; private set; } public static long BytesSerialized { get; private set; } @@ -27,74 +34,119 @@ namespace Robust.Shared.Serialization public static long ObjectsSerialized { get; private set; } public static long LargestObjectDeserializedBytes { get; private set; } + public static Type? LargestObjectDeserializedType { get; private set; } public static long BytesDeserialized { get; private set; } public static long ObjectsDeserialized { get; private set; } -#endregion + + #endregion public void Initialize() { - var types = reflectionManager.FindTypesWithAttribute().ToList(); -#if DEBUG + var mappedStringSerializer = new MappedStringSerializer(); + var types = _reflectionManager.FindTypesWithAttribute().ToList(); +#if !FULL_RELEASE + // confirm only shared types are marked for serialization, no client & server only types foreach (var type in types) { - if (type.Assembly.FullName!.Contains("Server") || type.Assembly.FullName.Contains("Client")) + if (type.Assembly.FullName!.Contains("Server")) { - throw new InvalidOperationException($"Type {type} is server/client specific but has a NetSerializableAttribute!"); + throw new InvalidOperationException($"Type {type} is server specific but has a NetSerializableAttribute!"); + } + + if (type.Assembly.FullName.Contains("Client")) + { + throw new InvalidOperationException($"Type {type} is client specific but has a NetSerializableAttribute!"); } } #endif - var settings = new Settings(); - Serializer = new Serializer(types, settings); - SerializableTypes = new HashSet(Serializer.GetTypeMap().Keys); + var settings = new Settings + { + CustomTypeSerializers = new ITypeSerializer[] {mappedStringSerializer} + }; + _serializer = new Serializer(types, settings); + _serializableTypes = new HashSet(_serializer.GetTypeMap().Keys); + + if (_netManager.IsClient) + { + MappedStringSerializer.LockMappedStrings = true; + } + else + { + var defaultAssemblies = AssemblyLoadContext.Default.Assemblies; + var gameAssemblies = _reflectionManager.Assemblies; + var robustShared = defaultAssemblies + .First(a => a.GetName().Name == "Robust.Shared"); + MappedStringSerializer.AddStrings(robustShared); + + // TODO: Need to add a GetSharedAssemblies method to the reflection manager + + var contentShared = gameAssemblies + .FirstOrDefault(a => a.GetName().Name == "Content.Shared"); + if (contentShared != null) + { + MappedStringSerializer.AddStrings(contentShared); + } + + // TODO: Need to add a GetServerAssemblies method to the reflection manager + + var contentServer = gameAssemblies + .FirstOrDefault(a => a.GetName().Name == "Content.Server"); + if (contentServer != null) + { + MappedStringSerializer.AddStrings(contentServer); + } + } + + MappedStringSerializer.NetworkInitialize(_netManager); } public void Serialize(Stream stream, object toSerialize) { var start = stream.Position; - Serializer.Serialize(stream, toSerialize); + _serializer.Serialize(stream, toSerialize); var end = stream.Position; var byteCount = end - start; BytesSerialized += byteCount; ++ObjectsSerialized; - if (byteCount > LargestObjectSerializedBytes) + if (byteCount <= LargestObjectSerializedBytes) { - LargestObjectSerializedBytes = byteCount; - LargestObjectSerializedType = toSerialize.GetType(); + return; } + + LargestObjectSerializedBytes = byteCount; + LargestObjectSerializedType = toSerialize.GetType(); } public T Deserialize(Stream stream) - { - return (T) Deserialize(stream); - } + => (T) Deserialize(stream); public object Deserialize(Stream stream) { var start = stream.Position; - var result = Serializer.Deserialize(stream); + var result = _serializer.Deserialize(stream); var end = stream.Position; var byteCount = end - start; BytesDeserialized += byteCount; ++ObjectsDeserialized; - if (byteCount > LargestObjectDeserializedBytes) + if (byteCount <= LargestObjectDeserializedBytes) { - LargestObjectDeserializedBytes = byteCount; - LargestObjectDeserializedType = result.GetType(); + return result; } + LargestObjectDeserializedBytes = byteCount; + LargestObjectDeserializedType = result.GetType(); + return result; } public bool CanSerialize(Type type) - { - return SerializableTypes.Contains(type); - } + => _serializableTypes.Contains(type); } diff --git a/Robust.Shared/Serialization/YamlObjectSerializer.cs b/Robust.Shared/Serialization/YamlObjectSerializer.cs index 38adf1f00..895e761ff 100644 --- a/Robust.Shared/Serialization/YamlObjectSerializer.cs +++ b/Robust.Shared/Serialization/YamlObjectSerializer.cs @@ -511,7 +511,7 @@ namespace Robust.Shared.Serialization } else { - throw new YamlException($"Malformed type tag."); + throw new YamlException("Malformed type tag."); } } diff --git a/Robust.Shared/Utility/TypeAbbreviation.cs b/Robust.Shared/Utility/TypeAbbreviation.cs index 6be644020..719c4765d 100644 --- a/Robust.Shared/Utility/TypeAbbreviation.cs +++ b/Robust.Shared/Utility/TypeAbbreviation.cs @@ -1,6 +1,7 @@ using System; using System.IO; using System.Text; +using Robust.Shared.Serialization; using YamlDotNet.RepresentationModel; namespace Robust.Shared.Utility @@ -27,6 +28,8 @@ namespace Robust.Shared.Utility var document = yamlStream.Documents[0]; _abbreviations = ParseAbbreviations((YamlSequenceNode) document.RootNode); + + RobustSerializer.MappedStringSerializer.AddStrings(yamlStream, "(embedded) Robust.Shared.Utility.TypeAbbreviations.yaml"); } /// diff --git a/Robust.UnitTesting/Shared/GameObjects/EntityState_Tests.cs b/Robust.UnitTesting/Shared/GameObjects/EntityState_Tests.cs index 48972bd7a..3630009ed 100644 --- a/Robust.UnitTesting/Shared/GameObjects/EntityState_Tests.cs +++ b/Robust.UnitTesting/Shared/GameObjects/EntityState_Tests.cs @@ -3,11 +3,19 @@ using System.Collections.Generic; using System.IO; using NUnit.Framework; using Robust.Server.Reflection; +using Robust.Shared.Configuration; using Robust.Shared.ContentPack; using Robust.Shared.GameObjects; +using Robust.Shared.GameObjects.Components.Map; +using Robust.Shared.Interfaces.Configuration; +using Robust.Shared.Interfaces.Log; +using Robust.Shared.Interfaces.Network; using Robust.Shared.Interfaces.Reflection; using Robust.Shared.Interfaces.Serialization; using Robust.Shared.IoC; +using Robust.Shared.Log; +using Robust.Shared.Map; +using Robust.Shared.Network; using Robust.Shared.Serialization; namespace Robust.UnitTesting.Shared.GameObjects @@ -23,24 +31,42 @@ namespace Robust.UnitTesting.Shared.GameObjects public void ComponentChangedSerialized() { var container = new DependencyCollection(); + container.Register(); + container.Register(); + container.Register(); container.Register(); container.Register(); container.BuildGraph(); container.Resolve().LoadAssemblies(AppDomain.CurrentDomain.GetAssemblyByName("Robust.Shared")); + IoCManager.InitThread(container); + + container.Resolve().Initialize(true); + var serializer = container.Resolve(); serializer.Initialize(); byte[] array; using(var stream = new MemoryStream()) { - var payload = new EntityState(new EntityUid(512), Array.Empty(), Array.Empty()); + var payload = new EntityState( + new EntityUid(512), + new [] + { + new ComponentChanged(false, NetIDs.MAP_GRID, nameof(MapGridComponent)), + }, + new [] + { + new MapGridComponentState(new GridId(0)), + }); serializer.Serialize(stream, payload); array = stream.ToArray(); } + IoCManager.Clear(); + Assert.Pass($"Size in Bytes: {array.Length.ToString()}"); } }