diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index d50bc820e..b951cd4b0 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -39,7 +39,10 @@ END TEMPLATE--> ### New features -*None yet* +* `IRobustSerializer` can now be configured to remove float NaN values when reading. + * This is intended to blanket block cheat clients from sending NaN values in input commands they shouldn't. + * To enable, set `IRobustSerializer.FloatFlags` from your content entrypoint. + * If you do really want to send NaN values while using the above, you can use the new `UnsafeFloat`, `UnsafeHalf`, and `UnsafeDouble` types to indicate a field that is exempt. ### Bugfixes diff --git a/Robust.Shared.IntegrationTests/Serialization/NetSerializerDefaultFloatTest.cs b/Robust.Shared.IntegrationTests/Serialization/NetSerializerDefaultFloatTest.cs new file mode 100644 index 000000000..787a9c8e1 --- /dev/null +++ b/Robust.Shared.IntegrationTests/Serialization/NetSerializerDefaultFloatTest.cs @@ -0,0 +1,203 @@ +using JetBrains.Annotations; +using NUnit.Framework; +using Robust.Shared.IoC; +using Robust.Shared.Maths; +using Robust.Shared.Serialization; +using Robust.UnitTesting.Shared; + +namespace Robust.Shared.IntegrationTests.Serialization; + +[Serializable, NetSerializable] +[UsedImplicitly(Reason = "Needed so RobustSerializer is guaranteed to pick up on the unsafe types.")] +internal sealed class MakeTheseSerializable +{ + public UnsafeFloat Single; + public UnsafeDouble Double; + public UnsafeHalf Half; + public Half SafeHalf; +} + +/// +/// Tests the serialization behavior of float types when is *not* set to do anything special. +/// Tests both primitives and Robust's "Unsafe" variants. +/// +[TestFixture, TestOf(typeof(RobustSerializer)), TestOf(typeof(NetUnsafeFloatSerializer))] +internal sealed class NetSerializerDefaultFloatTest : OurRobustUnitTest +{ + private IRobustSerializer _serializer = null!; + + [OneTimeSetUp] + public void Setup() + { + _serializer = IoCManager.Resolve(); + _serializer.Initialize(); + } + + internal static readonly TestCaseData[] PassThroughFloatTests = + [ + new TestCaseData(0.0).Returns(0.0), + new TestCaseData(1.0).Returns(1.0), + new TestCaseData(double.NaN).Returns(double.NaN), + new TestCaseData(double.PositiveInfinity).Returns(double.PositiveInfinity), + ]; + + [TestCaseSource(nameof(PassThroughFloatTests))] + public double TestSingle(double input) + { + var ms = new MemoryStream(); + _serializer.Serialize(ms, (float)input); + + ms.Position = 0; + + return _serializer.Deserialize(ms); + } + + [TestCaseSource(nameof(PassThroughFloatTests))] + public double TestUnsafeSingle(double input) + { + var ms = new MemoryStream(); + _serializer.Serialize(ms, (UnsafeFloat)input); + + ms.Position = 0; + + return _serializer.Deserialize(ms); + } + + [TestCaseSource(nameof(PassThroughFloatTests))] + public double TestDouble(double input) + { + var ms = new MemoryStream(); + _serializer.Serialize(ms, input); + + ms.Position = 0; + + return _serializer.Deserialize(ms); + } + + [TestCaseSource(nameof(PassThroughFloatTests))] + public double TestUnsafeDouble(double input) + { + var ms = new MemoryStream(); + _serializer.Serialize(ms, (UnsafeDouble)input); + + ms.Position = 0; + + return _serializer.Deserialize(ms); + } + + [TestCaseSource(nameof(PassThroughFloatTests))] + public double TestHalf(double input) + { + var ms = new MemoryStream(); + _serializer.Serialize(ms, (Half)input); + + ms.Position = 0; + + return (double)_serializer.Deserialize(ms); + } + + [TestCaseSource(nameof(PassThroughFloatTests))] + public double TestUnsafeHalf(double input) + { + var ms = new MemoryStream(); + _serializer.Serialize(ms, (UnsafeHalf)(Half)input); + + ms.Position = 0; + + return (double)(Half)_serializer.Deserialize(ms); + } +} + +/// +/// Tests the serialization behavior of float types when is set to remove NaNs on read. +/// Tests both primitives and Robust's "Unsafe" variants. +/// +[TestFixture] +[TestOf(typeof(RobustSerializer)), TestOf(typeof(NetUnsafeFloatSerializer)), TestOf(typeof(NetSafeFloatSerializer))] +internal sealed class NetSerializerSafeFloatTest : OurRobustUnitTest +{ + private IRobustSerializer _serializer = default!; + + [OneTimeSetUp] + public void Setup() + { + _serializer = IoCManager.Resolve(); + _serializer.FloatFlags = SerializerFloatFlags.RemoveReadNan; + _serializer.Initialize(); + } + + internal static readonly TestCaseData[] SafeFloatTests = + [ + new TestCaseData(0.0).Returns(0.0), + new TestCaseData(1.0).Returns(1.0), + new TestCaseData(double.NaN).Returns(0.0), + new TestCaseData(double.PositiveInfinity).Returns(double.PositiveInfinity), + ]; + + [TestCaseSource(nameof(SafeFloatTests))] + public double TestSingle(double input) + { + var ms = new MemoryStream(); + _serializer.Serialize(ms, (float)input); + + ms.Position = 0; + + return _serializer.Deserialize(ms); + } + + [TestCaseSource(typeof(NetSerializerDefaultFloatTest), nameof(NetSerializerDefaultFloatTest.PassThroughFloatTests))] + public double TestUnsafeSingle(double input) + { + var ms = new MemoryStream(); + _serializer.Serialize(ms, (UnsafeFloat)input); + + ms.Position = 0; + + return _serializer.Deserialize(ms); + } + + [TestCaseSource(nameof(SafeFloatTests))] + public double TestDouble(double input) + { + var ms = new MemoryStream(); + _serializer.Serialize(ms, input); + + ms.Position = 0; + + return _serializer.Deserialize(ms); + } + + [TestCaseSource(typeof(NetSerializerDefaultFloatTest), nameof(NetSerializerDefaultFloatTest.PassThroughFloatTests))] + public double TestUnsafeDouble(double input) + { + var ms = new MemoryStream(); + _serializer.Serialize(ms, (UnsafeDouble)input); + + ms.Position = 0; + + return _serializer.Deserialize(ms); + } + + + [TestCaseSource(nameof(SafeFloatTests))] + public double TestHalf(double input) + { + var ms = new MemoryStream(); + _serializer.Serialize(ms, (Half)input); + + ms.Position = 0; + + return (double)_serializer.Deserialize(ms); + } + + [TestCaseSource(typeof(NetSerializerDefaultFloatTest), nameof(NetSerializerDefaultFloatTest.PassThroughFloatTests))] + public double TestUnsafeHalf(double input) + { + var ms = new MemoryStream(); + _serializer.Serialize(ms, (UnsafeHalf)(Half)input); + + ms.Position = 0; + + return (double)(Half)_serializer.Deserialize(ms); + } +} diff --git a/Robust.Shared.Maths/UnsafeFloat.cs b/Robust.Shared.Maths/UnsafeFloat.cs new file mode 100644 index 000000000..ecf985ee0 --- /dev/null +++ b/Robust.Shared.Maths/UnsafeFloat.cs @@ -0,0 +1,53 @@ +using System; + +namespace Robust.Shared.Maths; + +/// +/// Marker type to indicate floating point values that should preserve NaNs across the network. +/// +/// +/// Robust's network serializer may be configured to flush NaN float values to 0, +/// to avoid exploits from lacking input validation. Even if this feature is enabled, +/// NaN values passed in this type are still untouched. +/// +/// The actual inner floating point value +/// +public readonly record struct UnsafeHalf(Half Value) +{ + public static implicit operator Half(UnsafeHalf f) => f.Value; + public static implicit operator UnsafeHalf(Half f) => new(f); +} + +/// +/// Marker type to indicate floating point values that should preserve NaNs across the network. +/// +/// +/// Robust's network serializer may be configured to flush NaN float values to 0, +/// to avoid exploits from lacking input validation. Even if this feature is enabled, +/// NaN values passed in this type are still untouched. +/// +/// The actual inner floating point value +/// +public readonly record struct UnsafeFloat(float Value) +{ + public static implicit operator float(UnsafeFloat f) => f.Value; + public static implicit operator UnsafeFloat(float f) => new(f); +} + +/// +/// Marker type to indicate floating point values that should preserve NaNs across the network. +/// +/// +/// Robust's network serializer may be configured to flush NaN float values to 0, +/// to avoid exploits from lacking input validation. Even if this feature is enabled, +/// NaN values passed in this type are still untouched. +/// +/// The actual inner floating point value +/// +public readonly record struct UnsafeDouble(double Value) +{ + public static implicit operator double(UnsafeDouble f) => f.Value; + public static implicit operator UnsafeDouble(double f) => new(f); + public static implicit operator UnsafeDouble(float f) => new(f); + public static implicit operator UnsafeDouble(UnsafeFloat f) => new(f); +} diff --git a/Robust.Shared/Serialization/IRobustSerializer.cs b/Robust.Shared/Serialization/IRobustSerializer.cs index ed1039f51..ea5536439 100644 --- a/Robust.Shared/Serialization/IRobustSerializer.cs +++ b/Robust.Shared/Serialization/IRobustSerializer.cs @@ -2,6 +2,8 @@ using System; using System.Collections.Generic; using System.IO; using System.Threading.Tasks; +using Robust.Shared.ContentPack; +using Robust.Shared.Maths; using Robust.Shared.Network; namespace Robust.Shared.Serialization @@ -9,6 +11,21 @@ namespace Robust.Shared.Serialization [NotContentImplementable] public interface IRobustSerializer { + /// + /// Specifies how the serializer should handle read floating point values. + /// + /// + /// Both sides of the network need not have the same float handling flags. + /// + /// + /// Thrown if set after the serializer has already been initialized. + /// (must be done from ) + /// + SerializerFloatFlags FloatFlags { get; set; } + + /// + /// Thrown if called twice. + /// void Initialize(); void Serialize(Stream stream, object toSerialize); @@ -70,4 +87,25 @@ namespace Robust.Shared.Serialization long BytesDeserialized { get; } long ObjectsDeserialized { get; } } + + /// + /// Flags for float handling. + /// + /// + /// These flags have no effect on values passed in a , or + /// . + /// + [Flags] + public enum SerializerFloatFlags + { + /// + /// No special behavior: floating point values are read exactly as sent over the network. + /// + None = 0, + + /// + /// Read NaN values will be cleared to zero. + /// + RemoveReadNan = 1 << 0, + } } diff --git a/Robust.Shared/Serialization/NetSafeFloatSerializer.cs b/Robust.Shared/Serialization/NetSafeFloatSerializer.cs new file mode 100644 index 000000000..d60f34966 --- /dev/null +++ b/Robust.Shared/Serialization/NetSafeFloatSerializer.cs @@ -0,0 +1,37 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Reflection; +using NetSerializer; + +namespace Robust.Shared.Serialization; + +/// +/// Replaces NetSerializer's default float handling to read NaN values as 0. +/// +internal sealed class NetSafeFloatSerializer : IStaticTypeSerializer +{ + public bool Handles(Type type) + { + return type == typeof(float) || type == typeof(double) || type == typeof(Half); + } + + public IEnumerable GetSubtypes(Type type) + { + return []; + } + + public MethodInfo GetStaticWriter(Type type) + { + return typeof(Primitives).GetMethod(nameof(Primitives.WritePrimitive), + BindingFlags.Public | BindingFlags.Static, + [typeof(Stream), type])!; + } + + public MethodInfo GetStaticReader(Type type) + { + return typeof(SafePrimitives).GetMethod(nameof(SafePrimitives.ReadPrimitive), + BindingFlags.Public | BindingFlags.Static, + [typeof(Stream), type.MakeByRefType()])!; + } +} diff --git a/Robust.Shared/Serialization/NetUnsafeFloatSerializer.cs b/Robust.Shared/Serialization/NetUnsafeFloatSerializer.cs new file mode 100644 index 000000000..d3a94ff0f --- /dev/null +++ b/Robust.Shared/Serialization/NetUnsafeFloatSerializer.cs @@ -0,0 +1,78 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Reflection; +using JetBrains.Annotations; +using NetSerializer; +using Robust.Shared.Maths; + +namespace Robust.Shared.Serialization; + +/// +/// NetSerializer type serializer for , , and . +/// +internal sealed class NetUnsafeFloatSerializer : IStaticTypeSerializer +{ + public bool Handles(Type type) + { + return type == typeof(UnsafeFloat) || type == typeof(UnsafeDouble) || type == typeof(UnsafeHalf); + } + + public IEnumerable GetSubtypes(Type type) + { + return []; + } + + public MethodInfo GetStaticWriter(Type type) + { + return typeof(NetUnsafeFloatSerializer).GetMethod(nameof(Write), + BindingFlags.NonPublic | BindingFlags.Static, + [typeof(Stream), type])!; + } + + public MethodInfo GetStaticReader(Type type) + { + return typeof(NetUnsafeFloatSerializer).GetMethod(nameof(Read), + BindingFlags.NonPublic | BindingFlags.Static, + [typeof(Stream), type.MakeByRefType()])!; + } + + [UsedImplicitly] + private static void Write(Stream stream, UnsafeFloat value) + { + Primitives.WritePrimitive(stream, value); + } + + [UsedImplicitly] + private static void Read(Stream stream, out UnsafeFloat value) + { + Primitives.ReadPrimitive(stream, out float readValue); + value = readValue; + } + + [UsedImplicitly] + private static void Write(Stream stream, UnsafeDouble value) + { + Primitives.WritePrimitive(stream, value); + } + + [UsedImplicitly] + private static void Read(Stream stream, out UnsafeDouble value) + { + Primitives.ReadPrimitive(stream, out double readValue); + value = readValue; + } + + [UsedImplicitly] + private static void Write(Stream stream, UnsafeHalf value) + { + Primitives.WritePrimitive(stream, value); + } + + [UsedImplicitly] + private static void Read(Stream stream, out UnsafeHalf value) + { + Primitives.ReadPrimitive(stream, out Half readValue); + value = readValue; + } +} diff --git a/Robust.Shared/Serialization/RobustSerializer.cs b/Robust.Shared/Serialization/RobustSerializer.cs index ba8cf9af5..cff7ae776 100644 --- a/Robust.Shared/Serialization/RobustSerializer.cs +++ b/Robust.Shared/Serialization/RobustSerializer.cs @@ -27,6 +27,8 @@ namespace Robust.Shared.Serialization private Serializer _serializer = default!; private HashSet _serializableTypes = default!; + private bool _initialized; + private SerializerFloatFlags _floatFlags; private static Type[] AlwaysNetSerializable => new[] { @@ -56,8 +58,25 @@ namespace Robust.Shared.Serialization #endregion + public SerializerFloatFlags FloatFlags + { + get => _floatFlags; + set + { + if (_initialized) + throw new InvalidOperationException("Already initialized!"); + + _floatFlags = value; + } + } + public void Initialize() { + if (_initialized) + throw new InvalidOperationException("Already initialized!"); + + _initialized = true; + var types = _reflectionManager.FindTypesWithAttribute() .OrderBy(x => x.FullName, StringComparer.InvariantCulture) .ToList(); @@ -91,9 +110,21 @@ namespace Robust.Shared.Serialization MappedStringSerializer.TypeSerializer, new NetMathSerializer(), new NetBitArraySerializer(), - new NetFormattedStringSerializer() + new NetFormattedStringSerializer(), + new NetUnsafeFloatSerializer(), } }; + + if ((_floatFlags & SerializerFloatFlags.RemoveReadNan) != 0) + { + settings.CustomTypeSerializers = + [ + ..settings.CustomTypeSerializers, + // This replaces NetSerializer's default serializer. + new NetSafeFloatSerializer() + ]; + } + _serializer = new Serializer(types, settings); _serializableTypes = new HashSet(_serializer.GetTypeMap().Keys); LogSzr.Info($"Serializer Types Hash: {_serializer.GetSHA256()}"); diff --git a/Robust.Shared/Serialization/SafePrimitives.cs b/Robust.Shared/Serialization/SafePrimitives.cs new file mode 100644 index 000000000..67bc19cfa --- /dev/null +++ b/Robust.Shared/Serialization/SafePrimitives.cs @@ -0,0 +1,45 @@ +using System; +using System.IO; +using JetBrains.Annotations; +using NetSerializer; + +namespace Robust.Shared.Serialization; + +/// +/// "Safer" read primitives as an alternative to . +/// +internal static class SafePrimitives +{ + /// + /// Read a float value from the stream, flushing NaNs to zero. + /// + [UsedImplicitly] + public static void ReadPrimitive(Stream stream, out float value) + { + Primitives.ReadPrimitive(stream, out float readFloat); + + value = float.IsNaN(readFloat) ? 0 : readFloat; + } + + /// + /// Read a double value from the stream, flushing NaNs to zero. + /// + [UsedImplicitly] + public static void ReadPrimitive(Stream stream, out double value) + { + Primitives.ReadPrimitive(stream, out double readDouble); + + value = double.IsNaN(readDouble) ? 0 : readDouble; + } + + /// + /// Read a double value from the stream, flushing NaNs to zero. + /// + [UsedImplicitly] + public static void ReadPrimitive(Stream stream, out Half value) + { + Primitives.ReadPrimitive(stream, out Half readDouble); + + value = Half.IsNaN(readDouble) ? Half.Zero : readDouble; + } +} diff --git a/Robust.Shared/Serialization/TypeSerializers/Implementations/Primitive/UnsafeFloatSerializer.cs b/Robust.Shared/Serialization/TypeSerializers/Implementations/Primitive/UnsafeFloatSerializer.cs new file mode 100644 index 000000000..90222329a --- /dev/null +++ b/Robust.Shared/Serialization/TypeSerializers/Implementations/Primitive/UnsafeFloatSerializer.cs @@ -0,0 +1,103 @@ +using Robust.Shared.IoC; +using Robust.Shared.Maths; +using Robust.Shared.Serialization.Manager; +using Robust.Shared.Serialization.Manager.Attributes; +using Robust.Shared.Serialization.Markdown; +using Robust.Shared.Serialization.Markdown.Validation; +using Robust.Shared.Serialization.Markdown.Value; +using Robust.Shared.Serialization.TypeSerializers.Interfaces; + +namespace Robust.Shared.Serialization.TypeSerializers.Implementations.Primitive; + +/// +/// Implementation of type serializers for and . +/// +/// +/// These don't need to do anything different from and , +/// because YAML cannot contain NaNs. +/// +[TypeSerializer] +internal sealed class UnsafeFloatSerializer : + ITypeSerializer, ITypeCopyCreator, + ITypeSerializer, ITypeCopyCreator +{ + ValidationNode ITypeValidator.Validate( + ISerializationManager serializationManager, + ValueDataNode node, + IDependencyCollection dependencies, + ISerializationContext? context) + { + return serializationManager.ValidateNode(node, context); + } + + public UnsafeFloat Read( + ISerializationManager serializationManager, + ValueDataNode node, + IDependencyCollection dependencies, + SerializationHookContext hookCtx, + ISerializationContext? context = null, + ISerializationManager.InstantiationDelegate? instanceProvider = null) + { + return serializationManager.Read(node, hookCtx, context); + } + + public DataNode Write( + ISerializationManager serializationManager, + UnsafeFloat value, + IDependencyCollection dependencies, + bool alwaysWrite = false, + ISerializationContext? context = null) + { + return serializationManager.WriteValue(value.Value, alwaysWrite, context); + } + + ValidationNode ITypeValidator.Validate( + ISerializationManager serializationManager, + ValueDataNode node, + IDependencyCollection dependencies, + ISerializationContext? context) + { + return serializationManager.ValidateNode(node, context); + } + + public UnsafeDouble Read( + ISerializationManager serializationManager, + ValueDataNode node, + IDependencyCollection dependencies, + SerializationHookContext hookCtx, + ISerializationContext? context = null, + ISerializationManager.InstantiationDelegate? instanceProvider = null) + { + return serializationManager.Read(node, hookCtx, context); + } + + public DataNode Write( + ISerializationManager serializationManager, + UnsafeDouble value, + IDependencyCollection dependencies, + bool alwaysWrite = false, + ISerializationContext? context = null) + { + return serializationManager.WriteValue(value.Value, alwaysWrite, context); + } + + public UnsafeFloat CreateCopy( + ISerializationManager serializationManager, + UnsafeFloat source, + IDependencyCollection dependencies, + SerializationHookContext hookCtx, + ISerializationContext? context = null) + { + return source; + } + + public UnsafeDouble CreateCopy( + ISerializationManager serializationManager, + UnsafeDouble source, + IDependencyCollection dependencies, + SerializationHookContext hookCtx, + ISerializationContext? context = null) + { + return source; + } +}