Files
RobustToolbox/Robust.Shared/Serialization/RobustSerializer.cs
PJB3005 65b8d0cce2 Add network serialization float NaN sanitization
Apparently cheat clients have figured out that none of SS14's code validates NaN inputs. Uh oh.

IRobustSerializer can now be configured to remove NaN values when reading. This is intended to be set on the server to completely block the issue.

Added "Unsafe" float types that can be used to bypass the new configurable behavior, in case somebody *really* needs NaNs.

An alternative option was to make a "SafeFloat" type and only apply the sanitization to that. The problem is that this would require updating hundreds if not thousands of messages in SS14, and would probably significantly confuse contributors about when to use which type. Blocking NaNs by default is likely to cause few issues while guaranteeing that the entire exploit is impossible.
2026-01-25 03:45:50 +01:00

263 lines
8.7 KiB
C#

using NetSerializer;
using Robust.Shared.IoC;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Numerics;
using System.Reflection;
using Robust.Shared.Log;
using Robust.Shared.Maths;
using Robust.Shared.Reflection;
using Robust.Shared.Utility;
namespace Robust.Shared.Serialization
{
    /// <summary>
    ///     NetSerializer-backed implementation of <see cref="IRobustSerializerInternal"/>.
    ///     The type map is built from every type tagged with <see cref="NetSerializableAttribute"/>
    ///     plus a small set of always-included math types, and running size statistics are kept
    ///     for everything that passes through <c>Serialize*</c> / <c>Deserialize*</c>.
    /// </summary>
    internal abstract partial class RobustSerializer : IRobustSerializerInternal
    {
        [Dependency] private readonly IReflectionManager _reflectionManager = default!;
        [Dependency] protected readonly IRobustMappedStringSerializer MappedStringSerializer = default!;
        [Dependency] private readonly ILogManager _logManager = default!;

        // Cache used by FindSerializedType: base type -> (serialized name -> resolved type).
        // A null entry records a name that was already searched for and not found.
        private readonly Dictionary<Type, Dictionary<string, Type?>> _cachedSerialized = new();

        private ISawmill LogSzr = default!;
        private Serializer _serializer = default!;
        private HashSet<Type> _serializableTypes = default!;
        private bool _initialized;
        private SerializerFloatFlags _floatFlags;

        // Types registered with the serializer regardless of attributes.
        // A fresh array is produced on every access.
        private static Type[] AlwaysNetSerializable => [typeof(Vector2i)];

        #region Statistics

        private readonly object _statsLock = new();

        // The setters are only called from EndMeasureSerialize / EndMeasureDeserialize while
        // holding _statsLock. The getters do not take the lock, so a reader on another thread
        // may observe values that are not mutually consistent.
        public long LargestObjectSerializedBytes { get; private set; }
        public Type? LargestObjectSerializedType { get; private set; }
        public long BytesSerialized { get; private set; }
        public long ObjectsSerialized { get; private set; }
        public long LargestObjectDeserializedBytes { get; private set; }
        public Type? LargestObjectDeserializedType { get; private set; }
        public long BytesDeserialized { get; private set; }
        public long ObjectsDeserialized { get; private set; }

        #endregion

        /// <summary>
        ///     Float-handling options (e.g. <see cref="SerializerFloatFlags.RemoveReadNan"/>) consumed
        ///     by <see cref="Initialize"/>.
        /// </summary>
        /// <exception cref="InvalidOperationException">Assigned after <see cref="Initialize"/> has run.</exception>
        public SerializerFloatFlags FloatFlags
        {
            get => _floatFlags;
            set => _floatFlags = _initialized
                ? throw new InvalidOperationException("Already initialized!")
                : value;
        }

        /// <summary>
        ///     Discovers all net-serializable types, initializes the mapped string serializer and
        ///     builds the underlying NetSerializer instance. May only be called once.
        /// </summary>
        /// <exception cref="InvalidOperationException">
        ///     Called a second time, or (debug builds) a client/server-specific type is marked net-serializable.
        /// </exception>
        public void Initialize()
        {
            if (_initialized)
                throw new InvalidOperationException("Already initialized!");

            // Marked before any of the work below, so a failure part-way through still leaves
            // this instance flagged as initialized.
            _initialized = true;

            // Sorted by full name so the resulting type list has a stable order.
            var netTypes = _reflectionManager
                .FindTypesWithAttribute<NetSerializableAttribute>()
                .OrderBy(t => t.FullName, StringComparer.InvariantCulture)
                .ToList();

#if DEBUG
            // Debug-only sanity check: only shared types may be marked for serialization.
            // A type from an assembly whose name contains "Server" or "Client" is rejected.
            foreach (var candidate in netTypes)
            {
                var assemblyName = candidate.Assembly.FullName!;

                if (assemblyName.Contains("Server"))
                    throw new InvalidOperationException($"Type {candidate} is server specific but has a NetSerializableAttribute!");

                if (assemblyName.Contains("Client"))
                    throw new InvalidOperationException($"Type {candidate} is client specific but has a NetSerializableAttribute!");
            }
#endif

            LogSzr = _logManager.GetSawmill("szr");

            netTypes.AddRange(AlwaysNetSerializable);
            netTypes.Add(typeof(Vector2));

            MappedStringSerializer.Initialize();

            var settings = new Settings
            {
                CustomTypeSerializers =
                [
                    MappedStringSerializer.TypeSerializer,
                    new NetMathSerializer(),
                    new NetBitArraySerializer(),
                    new NetFormattedStringSerializer(),
                    new NetUnsafeFloatSerializer(),
                ]
            };

            var stripNanOnRead = (_floatFlags & SerializerFloatFlags.RemoveReadNan) != 0;
            if (stripNanOnRead)
            {
                // Appending NetSafeFloatSerializer replaces NetSerializer's default serializer
                // for the types it handles.
                settings.CustomTypeSerializers = [..settings.CustomTypeSerializers, new NetSafeFloatSerializer()];
            }

            _serializer = new Serializer(netTypes, settings);
            _serializableTypes = _serializer.GetTypeMap().Keys.ToHashSet();

            LogSzr.Info($"Serializer Types Hash: {_serializer.GetSHA256()}");
        }

        /// <summary>Raw bytes of the serializer's SHA-256 type hash (decoded from its hex form).</summary>
        public byte[] GetSerializableTypesHash()
        {
            var hexHash = _serializer.GetSHA256();
            return Convert.FromHexString(hexHash);
        }

        /// <summary>The serializer's SHA-256 type hash as returned by NetSerializer (hex string).</summary>
        public string GetSerializableTypesHashString()
        {
            return _serializer.GetSHA256();
        }

        /// <summary>Writes NetSerializer's hash manifest to <paramref name="stream"/>.</summary>
        internal void GetHashManifest(Stream stream, bool writeNewline=false)
            => _serializer.GetHashManifest(stream, writeNewline);

        /// <summary>Delegates to the mapped string serializer's package generator.</summary>
        public (byte[] Hash, byte[] Package) GetStringSerializerPackage()
        {
            return MappedStringSerializer.GeneratePackage();
        }

        /// <summary>Returns the underlying serializer's type-to-id map.</summary>
        public Dictionary<Type, uint> GetTypeMap()
        {
            return _serializer.GetTypeMap();
        }

        /// <summary>Serializes <paramref name="toSerialize"/> (with type information) and records statistics.</summary>
        public void Serialize(Stream stream, object toSerialize)
        {
            var startPosition = StartMeasureStats(stream);
            _serializer.Serialize(stream, toSerialize);
            EndMeasureSerialize(stream, startPosition, toSerialize.GetType());
        }

        /// <summary>
        ///     Serializes a value of exactly type <typeparamref name="T"/> without a type header
        ///     and records statistics under <typeparamref name="T"/>.
        /// </summary>
        public void SerializeDirect<T>(Stream stream, T toSerialize)
        {
            DebugTools.Assert(toSerialize is null || toSerialize.GetType() == typeof(T),
                "Object must be of exact type specified in the generic parameter.");

            var startPosition = StartMeasureStats(stream);
            _serializer.SerializeDirect(stream, toSerialize);
            EndMeasureSerialize(stream, startPosition, typeof(T));
        }

        /// <summary>Deserializes an object and casts it to <typeparamref name="T"/>.</summary>
        public T Deserialize<T>(Stream stream)
        {
            return (T) Deserialize(stream);
        }

        /// <summary>Deserializes a header-less value of type <typeparamref name="T"/> and records statistics.</summary>
        public void DeserializeDirect<T>(Stream stream, out T value)
        {
            var startPosition = StartMeasureStats(stream);
            _serializer.DeserializeDirect(stream, out value);
            EndMeasureDeserialize(stream, startPosition, typeof(T));
        }

        /// <summary>Deserializes an object (with type information) and records statistics.</summary>
        public object Deserialize(Stream stream)
        {
            var startPosition = StartMeasureStats(stream);
            var deserialized = _serializer.Deserialize(stream);
            // NOTE(review): if the underlying serializer ever yields null, GetType() throws here.
            EndMeasureDeserialize(stream, startPosition, deserialized.GetType());
            return deserialized;
        }

        /// <summary>True if <paramref name="type"/> is in the serializer's type map.</summary>
        public bool CanSerialize(Type type)
        {
            return _serializableTypes.Contains(type);
        }

        /// <inheritdoc />
        public Type? FindSerializedType(Type assignableType, string serializedTypeName)
        {
            if (!_cachedSerialized.TryGetValue(assignableType, out var byName))
            {
                byName = new Dictionary<string, Type?>();
                _cachedSerialized[assignableType] = byName;
            }

            // Hits and misses are both cached, so each (base type, name) pair is searched at most once.
            if (byName.TryGetValue(serializedTypeName, out var cached))
                return cached;

            var match = _reflectionManager
                .GetAllChildren(assignableType)
                .FirstOrDefault(child =>
                    child.GetCustomAttribute<SerializedTypeAttribute>() is { } attribute
                    && attribute.SerializeName == serializedTypeName);

            byName[serializedTypeName] = match;
            return match;
        }

        /// <summary>Current stream position when seekable; otherwise 0.</summary>
        private static long StartMeasureStats(Stream stream)
        {
            if (!stream.CanSeek)
                return 0;

            return stream.Position;
        }

        /// <summary>
        ///     Counts one deserialized object; for seekable streams also adds the bytes consumed since
        ///     <paramref name="start"/> and tracks the largest object seen.
        /// </summary>
        private void EndMeasureDeserialize(Stream stream, long start, Type type)
        {
            lock (_statsLock)
            {
                ObjectsDeserialized++;

                if (!stream.CanSeek)
                    return;

                var length = stream.Position - start;
                BytesDeserialized += length;

                if (length <= LargestObjectDeserializedBytes)
                    return;

                LargestObjectDeserializedBytes = length;
                LargestObjectDeserializedType = type;
            }
        }

        /// <summary>
        ///     Counts one serialized object; for seekable streams also adds the bytes written since
        ///     <paramref name="start"/> and tracks the largest object seen.
        /// </summary>
        private void EndMeasureSerialize(Stream stream, long start, Type type)
        {
            lock (_statsLock)
            {
                ObjectsSerialized++;

                if (!stream.CanSeek)
                    return;

                var length = stream.Position - start;
                BytesSerialized += length;

                if (length <= LargestObjectSerializedBytes)
                    return;

                LargestObjectSerializedBytes = length;
                LargestObjectSerializedType = type;
            }
        }
    }
}