mirror of
https://github.com/space-wizards/RobustToolbox.git
synced 2026-02-14 19:29:36 +01:00
* WebSocket-based data transfer system * Move resource downloads/uploads to the new transfer system Should drastically increase the permitted practical size * Transfer impl for Lidgren * Async impl for receive stream * Use unbounded channel for Lidgren * Add metrics * More comments * Add serverside stream limit to avoid being a DoS vector * Fix tests * Oops forgot to actually implement sequence channels in NetMessage * Doc comment for NetMessage.SequenceChannel * Release notes
437 lines
14 KiB
C#
437 lines
14 KiB
C#
using System;
|
|
using System.Buffers;
|
|
using System.Buffers.Binary;
|
|
using System.Collections.Generic;
|
|
using System.IO;
|
|
using System.Net;
|
|
using System.Text;
|
|
using System.Threading;
|
|
using System.Threading.Channels;
|
|
using System.Threading.Tasks;
|
|
using Robust.Shared.Log;
|
|
using Robust.Shared.Utility;
|
|
|
|
namespace Robust.Shared.Network.Transfer;
|
|
|
|
internal abstract class BaseTransferImpl(ISawmill sawmill, BaseTransferManager parent, INetChannel channel) : IDisposable
|
|
{
|
|
// Custom framing format is as follows.
|
|
// <header message>
|
|
// uint8 opcode
|
|
// uint8 flags
|
|
// int64 transfer ID
|
|
// [if start message]:
|
|
// uint8 key length
|
|
// byte[] key
|
|
// <data message>
|
|
// just the fucking data lol
|
|
|
|
internal const int BufferSize = 16384;
|
|
internal const int MaxKeySize = 96;
|
|
internal const int MaxHeaderSize = 128;
|
|
|
|
protected readonly INetChannel Channel = channel;
|
|
protected readonly ISawmill Sawmill = sawmill;
|
|
|
|
protected long OutgoingIdCounter;
|
|
public int MaxChannelCount = int.MaxValue;
|
|
|
|
private readonly Dictionary<long, ChannelWriter<ArraySegment<byte>>> _receivingChannels = [];
|
|
|
|
private readonly SemaphoreSlim _socketSemaphore = new(1, 1);
|
|
internal readonly BaseTransferManager Parent = parent;
|
|
|
|
public abstract Task ServerInit();
|
|
public abstract Task ClientInit(CancellationToken cancel);
|
|
public abstract Stream StartTransfer(TransferStartInfo startInfo);
|
|
|
|
protected abstract bool BoundedChannel { get; }
|
|
|
|
private void TransferReceived(string key, ChannelReader<ArraySegment<byte>> reader)
|
|
{
|
|
if (_receivingChannels.Count >= MaxChannelCount)
|
|
{
|
|
Sawmill.Warning($"Disconnecting client {Channel} for breaching max channel count of {_receivingChannels}");
|
|
Channel.Disconnect("Reached max transfer channel count");
|
|
return;
|
|
}
|
|
|
|
// var stream = new ReceiveStream(reader);
|
|
// Parent.TransferReceived(key, Channel, stream);
|
|
}
|
|
|
|
protected void HandleHeaderReceived(
|
|
ReadOnlyMemory<byte> data,
|
|
out TransferFlags flags,
|
|
out long transferId,
|
|
out ChannelWriter<ArraySegment<byte>> channel)
|
|
{
|
|
ParseHeader(data.Span, out flags, out transferId, out var key);
|
|
|
|
if (!_receivingChannels.TryGetValue(transferId, out channel!))
|
|
{
|
|
if ((flags & TransferFlags.Start) == 0)
|
|
throw new ProtocolViolationException($"Received data for unknown transfer {transferId}");
|
|
|
|
DebugTools.Assert(key != null);
|
|
|
|
Sawmill.Verbose($"Starting transfer stream {transferId} with key {key}");
|
|
|
|
var fullChannel = BoundedChannel
|
|
? System.Threading.Channels.Channel.CreateBounded<ArraySegment<byte>>(
|
|
new BoundedChannelOptions(4)
|
|
{
|
|
SingleReader = true,
|
|
SingleWriter = true
|
|
})
|
|
: System.Threading.Channels.Channel.CreateUnbounded<ArraySegment<byte>>(new UnboundedChannelOptions
|
|
{
|
|
SingleReader = true,
|
|
SingleWriter = true
|
|
});
|
|
|
|
channel = fullChannel.Writer;
|
|
_receivingChannels.Add(transferId, channel);
|
|
|
|
TransferReceived(key, fullChannel.Reader);
|
|
}
|
|
}
|
|
|
|
protected void HandlePostData(TransferFlags flags, long transferId, ChannelWriter<ArraySegment<byte>> channel)
|
|
{
|
|
if ((flags & TransferFlags.Finish) != 0)
|
|
{
|
|
Sawmill.Verbose($"Finishing transfer stream {transferId}");
|
|
|
|
channel.Complete();
|
|
_receivingChannels.Remove(transferId);
|
|
}
|
|
}
|
|
|
|
private static void ParseHeader(
|
|
ReadOnlySpan<byte> buf,
|
|
out TransferFlags flags,
|
|
out long transferId,
|
|
out string? key)
|
|
{
|
|
flags = (TransferFlags)buf[1];
|
|
transferId = BinaryPrimitives.ReadInt64LittleEndian(buf[2..10]);
|
|
|
|
if ((flags & TransferFlags.Start) != 0)
|
|
{
|
|
var keyLength = buf[10];
|
|
key = Encoding.UTF8.GetString(buf.Slice(11, keyLength));
|
|
}
|
|
else
|
|
{
|
|
key = null;
|
|
}
|
|
}
|
|
|
|
private sealed class ReceiveStream : SaneStream
|
|
{
|
|
private readonly ChannelReader<ArraySegment<byte>> _bufferChannel;
|
|
|
|
private ArraySegment<byte> _currentBuffer;
|
|
|
|
public override bool CanRead => true;
|
|
|
|
public ReceiveStream(ChannelReader<ArraySegment<byte>> bufferChannel)
|
|
{
|
|
_bufferChannel = bufferChannel;
|
|
}
|
|
|
|
public override int Read(Span<byte> buffer)
|
|
{
|
|
var read = 0;
|
|
var remainingSpan = buffer;
|
|
|
|
while (remainingSpan.Length > 0)
|
|
{
|
|
if (_currentBuffer.Array == null || _currentBuffer.Count <= 0)
|
|
{
|
|
if (_currentBuffer.Array != null)
|
|
{
|
|
ArrayPool<byte>.Shared.Return(_currentBuffer.Array);
|
|
_currentBuffer = default;
|
|
}
|
|
|
|
if (!_bufferChannel.TryRead(out _currentBuffer))
|
|
{
|
|
// Only block if we haven't read any bytes yet.
|
|
if (read > 0 || !ReadNewBufferSync())
|
|
return read;
|
|
}
|
|
}
|
|
|
|
DebugTools.Assert(_currentBuffer.Array != null);
|
|
|
|
var remainingBuffer = _currentBuffer.Count;
|
|
var thisRead = Math.Min(remainingSpan.Length, remainingBuffer);
|
|
|
|
_currentBuffer.AsSpan(0, thisRead).CopyTo(remainingSpan);
|
|
remainingSpan = remainingSpan[thisRead..];
|
|
_currentBuffer = _currentBuffer[thisRead..];
|
|
read += thisRead;
|
|
}
|
|
|
|
return read;
|
|
}
|
|
|
|
public override async ValueTask<int> ReadAsync(
|
|
Memory<byte> buffer,
|
|
CancellationToken cancellationToken = default)
|
|
{
|
|
var read = 0;
|
|
var remainingSpan = buffer;
|
|
|
|
while (remainingSpan.Length > 0)
|
|
{
|
|
if (_currentBuffer.Array == null || _currentBuffer.Count <= 0)
|
|
{
|
|
if (_currentBuffer.Array != null)
|
|
{
|
|
ArrayPool<byte>.Shared.Return(_currentBuffer.Array);
|
|
_currentBuffer = default;
|
|
}
|
|
|
|
if (!_bufferChannel.TryRead(out _currentBuffer))
|
|
{
|
|
// Only block if we haven't read any bytes yet.
|
|
if (read > 0 || !await ReadNewBufferAsync())
|
|
return read;
|
|
}
|
|
}
|
|
|
|
DebugTools.Assert(_currentBuffer.Array != null);
|
|
|
|
var remainingBuffer = _currentBuffer.Count;
|
|
var thisRead = Math.Min(remainingSpan.Length, remainingBuffer);
|
|
|
|
_currentBuffer.AsMemory(0, thisRead).CopyTo(remainingSpan);
|
|
remainingSpan = remainingSpan[thisRead..];
|
|
_currentBuffer = _currentBuffer[thisRead..];
|
|
read += thisRead;
|
|
}
|
|
|
|
return read;
|
|
}
|
|
|
|
private bool ReadNewBufferSync()
|
|
{
|
|
DebugTools.Assert(_currentBuffer.Array == null);
|
|
|
|
var waitToRead = _bufferChannel.WaitToReadAsync();
|
|
#pragma warning disable RA0004
|
|
var waitToReadResult = waitToRead.AsTask().Result;
|
|
#pragma warning restore RA0004
|
|
if (!waitToReadResult)
|
|
return false;
|
|
|
|
return _bufferChannel.TryRead(out _currentBuffer);
|
|
}
|
|
|
|
private async Task<bool> ReadNewBufferAsync()
|
|
{
|
|
DebugTools.Assert(_currentBuffer.Array == null);
|
|
|
|
var waitToRead = await _bufferChannel.WaitToReadAsync();
|
|
if (!waitToRead)
|
|
return false;
|
|
|
|
return _bufferChannel.TryRead(out _currentBuffer);
|
|
}
|
|
|
|
protected override void Dispose(bool disposing)
|
|
{
|
|
base.Dispose(disposing);
|
|
|
|
if (disposing && _currentBuffer.Array != null)
|
|
ArrayPool<byte>.Shared.Return(_currentBuffer.Array);
|
|
}
|
|
}
|
|
|
|
protected abstract class ChunkedSendStream : SaneStream
|
|
{
|
|
protected readonly BaseTransferImpl Parent;
|
|
private readonly long _id;
|
|
private readonly string _key;
|
|
|
|
private readonly byte[] _headerBuffer;
|
|
private readonly byte[] _dataBuffer;
|
|
private bool _isFirstTransmission = true;
|
|
private int _bufferPos;
|
|
|
|
public override bool CanWrite => true;
|
|
|
|
public ChunkedSendStream(BaseTransferImpl parent, long id, string key)
|
|
{
|
|
// This just has to be < buffer size & < ushort.MaxValue
|
|
// (when accounting for UTF-8 possibly being more code units than UTF-16)
|
|
if (Encoding.UTF8.GetByteCount(key) > MaxKeySize)
|
|
throw new ArgumentException("Key too long");
|
|
|
|
Parent = parent;
|
|
_id = id;
|
|
_key = key;
|
|
|
|
_headerBuffer = ArrayPool<byte>.Shared.Rent(MaxHeaderSize);
|
|
_dataBuffer = ArrayPool<byte>.Shared.Rent(BufferSize);
|
|
}
|
|
|
|
public override void Write(ReadOnlySpan<byte> buffer)
|
|
{
|
|
while (buffer.Length > 0)
|
|
{
|
|
var remainingBufferSpace = _dataBuffer.AsSpan(_bufferPos);
|
|
var thisChunk = Math.Min(remainingBufferSpace.Length, buffer.Length);
|
|
var thisSpan = buffer[..thisChunk];
|
|
|
|
thisSpan.CopyTo(remainingBufferSpace);
|
|
_bufferPos += thisChunk;
|
|
|
|
if (_bufferPos == _dataBuffer.Length)
|
|
Flush();
|
|
|
|
buffer = buffer[thisChunk..];
|
|
}
|
|
}
|
|
|
|
public override async ValueTask WriteAsync(
|
|
ReadOnlyMemory<byte> buffer,
|
|
CancellationToken cancellationToken = default)
|
|
{
|
|
while (buffer.Length > 0)
|
|
{
|
|
var remainingBufferSpace = _dataBuffer.AsSpan(_bufferPos);
|
|
var thisChunk = Math.Min(remainingBufferSpace.Length, buffer.Length);
|
|
var thisSpan = buffer[..thisChunk];
|
|
|
|
thisSpan.Span.CopyTo(remainingBufferSpace);
|
|
_bufferPos += thisChunk;
|
|
|
|
if (_bufferPos == _dataBuffer.Length)
|
|
await FlushAsync(cancellationToken).ConfigureAwait(false);
|
|
|
|
buffer = buffer[thisChunk..];
|
|
}
|
|
}
|
|
|
|
public override void Flush()
|
|
{
|
|
FlushAsync().Wait();
|
|
}
|
|
|
|
public override async Task FlushAsync(CancellationToken cancellationToken)
|
|
{
|
|
await FlushAsync(finish: false, cancellationToken).ConfigureAwait(false);
|
|
}
|
|
|
|
private async ValueTask FlushAsync(bool finish, CancellationToken cancel = default)
|
|
{
|
|
var headerLength = 10;
|
|
|
|
var opcode = Opcode.Transfer;
|
|
var flags = TransferFlags.None;
|
|
if (_isFirstTransmission)
|
|
flags |= TransferFlags.Start;
|
|
if (_bufferPos > 0)
|
|
flags |= TransferFlags.HasData;
|
|
if (finish)
|
|
flags |= TransferFlags.Finish;
|
|
|
|
if (flags == TransferFlags.None)
|
|
{
|
|
// Nothing to flush, whatsoever.
|
|
return;
|
|
}
|
|
|
|
_headerBuffer[0] = (byte)opcode;
|
|
_headerBuffer[1] = (byte)flags;
|
|
BinaryPrimitives.WriteInt64LittleEndian(_headerBuffer.AsSpan(2..10), _id);
|
|
|
|
if (_isFirstTransmission)
|
|
{
|
|
var written = Encoding.UTF8.GetBytes(_key, _headerBuffer.AsSpan(11..));
|
|
DebugTools.Assert(written < byte.MaxValue);
|
|
_headerBuffer[10] = (byte)written;
|
|
|
|
headerLength += 1;
|
|
headerLength += written;
|
|
}
|
|
|
|
// Send.
|
|
using (await Parent._socketSemaphore.WaitGuardAsync().ConfigureAwait(false))
|
|
{
|
|
await SendChunkAsync(
|
|
new ArraySegment<byte>(_headerBuffer, 0, headerLength),
|
|
cancel)
|
|
.ConfigureAwait(false);
|
|
|
|
if (_bufferPos > 0)
|
|
{
|
|
await SendChunkAsync(
|
|
new ArraySegment<byte>(_dataBuffer, 0, _bufferPos),
|
|
cancel)
|
|
.ConfigureAwait(false);
|
|
|
|
_bufferPos = 0;
|
|
}
|
|
}
|
|
|
|
_isFirstTransmission = false;
|
|
}
|
|
|
|
protected abstract ValueTask SendChunkAsync(
|
|
ArraySegment<byte> buffer,
|
|
CancellationToken cancellationToken);
|
|
|
|
protected override void Dispose(bool disposing)
|
|
{
|
|
FlushAsync(finish: true).AsTask().Wait();
|
|
DisposeCore();
|
|
}
|
|
|
|
public override async ValueTask DisposeAsync()
|
|
{
|
|
GC.SuppressFinalize(this);
|
|
await FlushAsync(finish: true).ConfigureAwait(false);
|
|
DisposeCore();
|
|
}
|
|
|
|
private void DisposeCore()
|
|
{
|
|
ArrayPool<byte>.Shared.Return(_dataBuffer);
|
|
ArrayPool<byte>.Shared.Return(_headerBuffer);
|
|
}
|
|
|
|
~ChunkedSendStream()
|
|
{
|
|
// Have to do this so the stream isn't permanently hanging on the receiving side.
|
|
FlushAsync(finish: true).AsTask().Wait();
|
|
}
|
|
}
|
|
|
|
public virtual void Dispose()
|
|
{
|
|
foreach (var channel in _receivingChannels.Values)
|
|
{
|
|
channel.Complete();
|
|
}
|
|
}
|
|
|
|
protected enum Opcode : byte
|
|
{
|
|
Transfer = 0,
|
|
}
|
|
|
|
[Flags]
|
|
protected enum TransferFlags : byte
|
|
{
|
|
None = 0,
|
|
Start = 1 << 0,
|
|
Finish = 1 << 1,
|
|
HasData = 1 << 2,
|
|
}
|
|
}
|