* Some stuff for auth

* Holy crap auth works

* Enable encryption even if no auth token is provided.

It's still possible that the public key was retrieved over HTTPS via the status API, in which case it will be secure.

* Fix integration test compile.

* Secure CVar API.

* Literally rewrite the auth protocol to be minecraft's.

* Better exception tolerance in server handshake.

* Auth works from launcher.

* Fix some usages of UserID instead of UserName

* Fix auth.server CVar

* Kick existing connection if same account connects twice.

* Username assignment, guest session distinguishing.

* Necessary work to make bans work.

* Expose LoginType to OnConnecting.

* Fixing tests and warnings.
This commit is contained in:
Pieter-Jan Briers
2020-09-29 14:18:12 +02:00
committed by GitHub
parent e7a49cc1f0
commit aa64528a03
54 changed files with 1270 additions and 389 deletions

View File

@@ -4,7 +4,6 @@ using Robust.Client.Interfaces;
using Robust.Client.Interfaces.Debugging;
using Robust.Client.Interfaces.GameObjects;
using Robust.Client.Interfaces.GameStates;
using Robust.Client.Interfaces.State;
using Robust.Client.Interfaces.Utility;
using Robust.Client.Player;
using Robust.Shared.Enums;
@@ -16,7 +15,6 @@ using Robust.Shared.IoC;
using Robust.Shared.Log;
using Robust.Shared.Network;
using Robust.Shared.Network.Messages;
using Robust.Shared.Players;
using Robust.Shared.Utility;
namespace Robust.Client
@@ -175,16 +173,18 @@ namespace Robust.Client
}
info.ServerMaxPlayers = msg.ServerMaxPlayers;
info.SessionId = msg.PlayerSessionId;
info.TickRate = msg.TickRate;
_timing.TickRate = msg.TickRate;
Logger.InfoS("client", $"Tickrate changed to: {msg.TickRate}");
_discord.Update(info.ServerName, info.SessionId.Username, info.ServerMaxPlayers.ToString());
var userName = msg.MsgChannel.UserName;
var userId = msg.MsgChannel.UserId;
_discord.Update(info.ServerName, userName, info.ServerMaxPlayers.ToString());
// start up player management
_playMan.Startup(_net.ServerChannel!);
_playMan.LocalPlayer!.SessionId = info.SessionId;
_playMan.LocalPlayer!.UserId = userId;
_playMan.LocalPlayer.Name = userName;
_playMan.LocalPlayer.StatusChanged += OnLocalStatusChanged;
}
@@ -312,7 +312,5 @@ namespace Robust.Client
public int ServerMaxPlayers { get; set; }
public byte TickRate { get; internal set; }
public NetSessionId SessionId { get; set; }
}
}

View File

@@ -150,7 +150,7 @@ namespace Robust.Client.Graphics.Clyde
_configurationManager.RegisterCVar("display.ogl_check_errors", false, onValueChanged: b => _checkGLErrors = b);
// This cvar does not modify the actual GL version requested or anything,
// it overrides the version we detect to detect GL features.
_configurationManager.RegisterCVar<string?>("display.ogl_override_version", null);
_configurationManager.RegisterCVar("display.ogl_override_version", "");
RegisterBlockCVars();
}
@@ -240,8 +240,8 @@ namespace Robust.Client.Graphics.Clyde
private (int major, int minor)? ParseGLOverrideVersion()
{
var overrideGLVersion = _configurationManager.GetCVar<string?>("display.ogl_override_version");
if (overrideGLVersion == null)
var overrideGLVersion = _configurationManager.GetCVar<string>("display.ogl_override_version");
if (string.IsNullOrEmpty(overrideGLVersion))
{
return null;
}

View File

@@ -10,7 +10,7 @@ namespace Robust.Client.Player
public interface IPlayerManager
{
IEnumerable<IPlayerSession> Sessions { get; }
IReadOnlyDictionary<NetSessionId, IPlayerSession> SessionsDict { get; }
IReadOnlyDictionary<NetUserId, IPlayerSession> SessionsDict { get; }
LocalPlayer? LocalPlayer { get; }

View File

@@ -36,7 +36,7 @@ namespace Robust.Client.Player
[ViewVariables] public IEntity? ControlledEntity { get; private set; }
[ViewVariables] public NetSessionId SessionId { get; set; }
[ViewVariables] public NetUserId UserId { get; set; }
/// <summary>
/// Session of the local client.
@@ -49,7 +49,8 @@ namespace Robust.Client.Player
/// <summary>
/// OOC name of the local player.
/// </summary>
[ViewVariables] public string Name => SessionId.Username;
[ViewVariables]
public string Name { get; set; } = default!;
/// <summary>
/// The status of the client's session has changed.

View File

@@ -32,7 +32,7 @@ namespace Robust.Client.Player
/// <summary>
/// Active sessions of connected clients to the server.
/// </summary>
private readonly Dictionary<NetSessionId, IPlayerSession> _sessions = new Dictionary<NetSessionId, IPlayerSession>();
private readonly Dictionary<NetUserId, IPlayerSession> _sessions = new Dictionary<NetUserId, IPlayerSession>();
/// <inheritdoc />
public int PlayerCount => _sessions.Values.Count;
@@ -47,7 +47,7 @@ namespace Robust.Client.Player
[ViewVariables] public IEnumerable<IPlayerSession> Sessions => _sessions.Values;
/// <inheritdoc />
public IReadOnlyDictionary<NetSessionId, IPlayerSession> SessionsDict => _sessions;
public IReadOnlyDictionary<NetUserId, IPlayerSession> SessionsDict => _sessions;
/// <inheritdoc />
public event EventHandler? PlayerListUpdated;
@@ -96,7 +96,7 @@ namespace Robust.Client.Player
DebugTools.Assert(LocalPlayer != null, "Call Startup()");
DebugTools.Assert(LocalPlayer!.Session != null, "Received player state before Session finished setup.");
var myState = list.FirstOrDefault(s => s.SessionId == LocalPlayer.SessionId);
var myState = list.FirstOrDefault(s => s.UserId == LocalPlayer.UserId);
if (myState != null)
{
@@ -151,13 +151,13 @@ namespace Robust.Client.Player
{
var dirty = false;
var hitSet = new List<NetSessionId>();
var hitSet = new List<NetUserId>();
foreach (var state in remotePlayers)
{
hitSet.Add(state.SessionId);
hitSet.Add(state.UserId);
if (_sessions.TryGetValue(state.SessionId, out var local))
if (_sessions.TryGetValue(state.UserId, out var local))
{
// Exists, update data.
if (local.Name == state.Name && local.Status == state.Status && local.Ping == state.Ping)
@@ -173,14 +173,14 @@ namespace Robust.Client.Player
// New, give him a slot.
dirty = true;
var newSession = new PlayerSession(state.SessionId)
var newSession = new PlayerSession(state.UserId)
{
Name = state.Name,
Status = state.Status,
Ping = state.Ping
};
_sessions.Add(state.SessionId, newSession);
if (state.SessionId == LocalPlayer!.SessionId)
_sessions.Add(state.UserId, newSession);
if (state.UserId == LocalPlayer!.UserId)
{
LocalPlayer.InternalSession = newSession;
@@ -195,7 +195,7 @@ namespace Robust.Client.Player
// clear slot, player left
if (!hitSet.Contains(existing))
{
DebugTools.Assert(LocalPlayer!.SessionId != existing, "I'm still connected to the server, but i left?");
DebugTools.Assert(LocalPlayer!.UserId != existing, "I'm still connected to the server, but i left?");
_sessions.Remove(existing);
dirty = true;
}

View File

@@ -1,4 +1,5 @@
using Robust.Shared.Enums;
using System;
using Robust.Shared.Enums;
using Robust.Shared.Interfaces.GameObjects;
using Robust.Shared.Network;
@@ -13,7 +14,7 @@ namespace Robust.Client.Player
public IEntity? AttachedEntity { get; set; }
/// <inheritdoc />
public NetSessionId SessionId { get; }
public NetUserId UserId { get; }
/// <inheritdoc cref="IPlayerSession" />
public string Name { get; set; } = "<Unknown>";
@@ -24,9 +25,9 @@ namespace Robust.Client.Player
/// <summary>
/// Creates an instance of a PlayerSession.
/// </summary
public PlayerSession(NetSessionId session)
public PlayerSession(NetUserId user)
{
SessionId = session;
UserId = user;
}
}
}

View File

@@ -117,12 +117,9 @@ namespace Robust.Server.Console.Commands
var name = args[0];
var index = new NetSessionId(name);
if (players.ValidSessionId(index))
if (players.TryGetSessionByUsername(name, out var target))
{
var network = IoCManager.Resolve<IServerNetManager>();
var targetPlyr = players.GetSessionById(index);
var reason = "Kicked by console.";
if (args.Length >= 2)
@@ -130,7 +127,7 @@ namespace Robust.Server.Console.Commands
reason = reason + args[1];
}
network.DisconnectChannel(targetPlyr.ConnectedClient, reason);
network.DisconnectChannel(target.ConnectedClient, reason);
}
}
}

View File

@@ -120,7 +120,7 @@ namespace Robust.Server.GameObjects
if (msgT < cT && _logLateMsgs)
{
Logger.WarningS("net.ent", "Got late MsgEntity! Diff: {0}, msgT: {2}, cT: {3}, player: {1}",
(int) msgT.Value - (int) cT.Value, message.MsgChannel.SessionId, msgT, cT);
(int) msgT.Value - (int) cT.Value, message.MsgChannel.UserName, msgT, cT);
}
DispatchEntityNetworkMessage(message);

View File

@@ -16,7 +16,7 @@ namespace Robust.Server.Interfaces.Player
/// <summary>
/// The session ID of the player owning this data.
/// </summary>
NetSessionId SessionId { get; }
NetUserId UserId { get; }
/// <summary>
/// Custom field that content can assign anything to.

View File

@@ -42,26 +42,38 @@ namespace Robust.Server.Interfaces.Player
/// <param name="maxPlayers">Maximum number of players that can connect to this server at one time.</param>
void Initialize(int maxPlayers);
bool TryGetSessionByUsername(string username, [NotNullWhen(true)] out IPlayerSession? session);
/// <summary>
/// Returns the client session of the networkId.
/// </summary>
/// <returns></returns>
IPlayerSession GetSessionById(NetSessionId index);
IPlayerSession GetSessionByUserId(NetUserId index);
IPlayerSession GetSessionByChannel(INetChannel channel);
bool TryGetSessionByChannel(INetChannel channel, [NotNullWhen(true)] out IPlayerSession? session);
bool TryGetSessionById(NetSessionId sessionId, [NotNullWhen(true)] out IPlayerSession? session);
bool TryGetSessionById(NetUserId userId, [NotNullWhen(true)] out IPlayerSession? session);
/// <summary>
/// Checks to see if a PlayerIndex is a valid session.
/// </summary>
bool ValidSessionId(NetSessionId index);
bool ValidSessionId(NetUserId index);
IPlayerData GetPlayerData(NetSessionId sessionId);
bool TryGetPlayerData(NetSessionId sessionId, [NotNullWhen(true)] out IPlayerData? data);
bool HasPlayerData(NetSessionId sessionId);
IPlayerData GetPlayerData(NetUserId userId);
bool TryGetPlayerData(NetUserId userId, [NotNullWhen(true)] out IPlayerData? data);
bool TryGetPlayerDataByUsername(string userName, [NotNullWhen(true)] out IPlayerData? data);
bool HasPlayerData(NetUserId userId);
/// <summary>
/// Tries to get the user ID of the user with the specified username.
/// </summary>
/// <remarks>
/// This only works if this user has already connected once before during this server run.
/// It does still work if the user has since disconnected.
/// </remarks>
bool TryGetUserId(string userName, out NetUserId userId);
IEnumerable<IPlayerData> GetAllPlayerData();

View File

@@ -3,6 +3,7 @@ using Robust.Server.Player;
using Robust.Shared.GameObjects;
using Robust.Shared.Interfaces.GameObjects;
using Robust.Shared.Interfaces.Network;
using Robust.Shared.Network;
using Robust.Shared.Players;
namespace Robust.Server.Interfaces.Player
@@ -23,6 +24,8 @@ namespace Robust.Server.Interfaces.Player
void JoinGame();
LoginType AuthType { get; }
/// <summary>
/// Attaches this player to an entity.
/// NOTE: The content pack almost certainly has an alternative for this.

View File

@@ -6,13 +6,13 @@ namespace Robust.Server.Player
{
class PlayerData : IPlayerData
{
public PlayerData(NetSessionId sessionId)
public PlayerData(NetUserId userId)
{
SessionId = sessionId;
UserId = userId;
}
[ViewVariables]
public NetSessionId SessionId { get; }
public NetUserId UserId { get; }
[ViewVariables]
public object? ContentDataUncast { get; set; }

View File

@@ -50,10 +50,14 @@ namespace Robust.Server.Player
/// Active sessions of connected clients to the server.
/// </summary>
[ViewVariables]
private readonly Dictionary<NetSessionId, PlayerSession> _sessions = new Dictionary<NetSessionId, PlayerSession>();
private readonly Dictionary<NetUserId, PlayerSession> _sessions = new Dictionary<NetUserId, PlayerSession>();
[ViewVariables]
private readonly Dictionary<NetSessionId, PlayerData> _playerData = new Dictionary<NetSessionId, PlayerData>();
private readonly Dictionary<NetUserId, PlayerData> _playerData = new Dictionary<NetUserId, PlayerData>();
[ViewVariables]
private readonly Dictionary<string, NetUserId> _userIdMap = new Dictionary<string, NetUserId>();
/// <inheritdoc />
[ViewVariables]
@@ -98,6 +102,24 @@ namespace Robust.Server.Player
_network.Disconnect += EndSession;
}
public bool TryGetSessionByUsername(string username, [NotNullWhen(true)] out IPlayerSession? session)
{
if (!_userIdMap.TryGetValue(username, out var userId))
{
session = null;
return false;
}
if (_sessions.TryGetValue(userId, out var iSession))
{
session = iSession;
return true;
}
session = null;
return false;
}
IPlayerSession IPlayerManager.GetSessionByChannel(INetChannel channel) => GetSessionByChannel(channel);
public bool TryGetSessionByChannel(INetChannel channel, [NotNullWhen(true)] out IPlayerSession? session)
{
@@ -105,7 +127,7 @@ namespace Robust.Server.Player
try
{
// Should only be one session per client. Returns that session, in theory.
if (_sessions.TryGetValue(channel.SessionId, out var concrete))
if (_sessions.TryGetValue(channel.UserId, out var concrete))
{
session = concrete;
return true;
@@ -126,7 +148,7 @@ namespace Robust.Server.Player
try
{
// Should only be one session per client. Returns that session, in theory.
return _sessions[channel.SessionId];
return _sessions[channel.UserId];
}
finally
{
@@ -135,7 +157,7 @@ namespace Robust.Server.Player
}
/// <inheritdoc />
public IPlayerSession GetSessionById(NetSessionId index)
public IPlayerSession GetSessionByUserId(NetUserId index)
{
_sessionsLock.EnterReadLock();
try
@@ -148,7 +170,7 @@ namespace Robust.Server.Player
}
}
public bool ValidSessionId(NetSessionId index)
public bool ValidSessionId(NetUserId index)
{
_sessionsLock.EnterReadLock();
try
@@ -161,12 +183,12 @@ namespace Robust.Server.Player
}
}
public bool TryGetSessionById(NetSessionId sessionId, [NotNullWhen(true)] out IPlayerSession? session)
public bool TryGetSessionById(NetUserId userId, [NotNullWhen(true)] out IPlayerSession? session)
{
_sessionsLock.EnterReadLock();
try
{
if (_sessions.TryGetValue(sessionId, out var _session))
if (_sessions.TryGetValue(userId, out var _session))
{
session = _session;
return true;
@@ -197,6 +219,11 @@ namespace Robust.Server.Player
}
}
public bool TryGetUserId(string userName, out NetUserId userId)
{
return _userIdMap.TryGetValue(userName, out userId);
}
public IEnumerable<IPlayerData> GetAllPlayerData()
{
return _playerData.Values;
@@ -242,7 +269,7 @@ namespace Robust.Server.Player
_sessionsLock.ExitReadLock();
}
}
/// <summary>
/// Gets all players inside of a circle.
/// </summary>
@@ -312,10 +339,14 @@ namespace Robust.Server.Player
}
}
private void OnConnecting(object? sender, NetConnectingArgs args)
private Task OnConnecting(NetConnectingArgs args)
{
if (PlayerCount >= _baseServer.MaxPlayers)
args.Deny = true;
{
args.Deny("The server is full.");
}
return Task.CompletedTask;
}
/// <summary>
@@ -325,11 +356,14 @@ namespace Robust.Server.Player
/// <param name="args"></param>
private void NewSession(object? sender, NetChannelArgs args)
{
if (!_playerData.TryGetValue(args.Channel.SessionId, out var data))
if (!_playerData.TryGetValue(args.Channel.UserId, out var data))
{
data = new PlayerData(args.Channel.SessionId);
_playerData.Add(args.Channel.SessionId, data);
data = new PlayerData(args.Channel.UserId);
_playerData.Add(args.Channel.UserId, data);
}
_userIdMap[args.Channel.UserName] = args.Channel.UserId;
var session = new PlayerSession(this, args.Channel, data);
session.PlayerStatusChanged += (obj, sessionArgs) => OnPlayerStatusChanged(session, sessionArgs.OldStatus, sessionArgs.NewStatus);
@@ -337,7 +371,7 @@ namespace Robust.Server.Player
_sessionsLock.EnterWriteLock();
try
{
_sessions.Add(args.Channel.SessionId, session);
_sessions.Add(args.Channel.UserId, session);
}
finally
{
@@ -370,7 +404,7 @@ namespace Robust.Server.Player
_sessionsLock.EnterWriteLock();
try
{
_sessions.Remove(session.SessionId);
_sessions.Remove(session.UserId);
}
finally
{
@@ -381,7 +415,7 @@ namespace Robust.Server.Player
Dirty();
}
private async void HandleWelcomeMessageReq(MsgServerInfoReq message)
private void HandleWelcomeMessageReq(MsgServerInfoReq message)
{
var channel = message.MsgChannel;
var netMsg = channel.CreateNetMessage<MsgServerInfo>();
@@ -390,15 +424,6 @@ namespace Robust.Server.Player
netMsg.ServerMaxPlayers = _baseServer.MaxPlayers;
netMsg.TickRate = _timing.TickRate;
IPlayerSession? session;
while (!TryGetSessionByChannel(channel, out session))
{
await Task.Delay(10);
if (!channel.IsConnected) return;
}
netMsg.PlayerSessionId = session.SessionId;
channel.SendMessage(netMsg);
}
@@ -422,7 +447,7 @@ namespace Robust.Server.Player
var info = new PlayerState
{
SessionId = client.SessionId,
UserId = client.UserId,
Name = client.Name,
Status = client.Status,
Ping = client.ConnectedClient.Ping
@@ -440,14 +465,14 @@ namespace Robust.Server.Player
_lastStateUpdate = _timing.CurTick;
}
public IPlayerData GetPlayerData(NetSessionId sessionId)
public IPlayerData GetPlayerData(NetUserId userId)
{
return _playerData[sessionId];
return _playerData[userId];
}
public bool TryGetPlayerData(NetSessionId sessionId, [NotNullWhen(true)] out IPlayerData? data)
public bool TryGetPlayerData(NetUserId userId, [NotNullWhen(true)] out IPlayerData? data)
{
if (_playerData.TryGetValue(sessionId, out var _data))
if (_playerData.TryGetValue(userId, out var _data))
{
data = _data;
return true;
@@ -456,9 +481,22 @@ namespace Robust.Server.Player
return false;
}
public bool HasPlayerData(NetSessionId sessionId)
public bool TryGetPlayerDataByUsername(string userName, [NotNullWhen(true)] out IPlayerData? data)
{
return _playerData.ContainsKey(sessionId);
if (!_userIdMap.TryGetValue(userName, out var userId))
{
data = null;
return false;
}
// PlayerData is initialized together with the _userIdMap so we can trust that it'll be present.
data = _playerData[userId];
return true;
}
public bool HasPlayerData(NetUserId userId)
{
return _playerData.ContainsKey(userId);
}
}

View File

@@ -22,12 +22,13 @@ namespace Robust.Server.Player
public PlayerSession(PlayerManager playerManager, INetChannel client, PlayerData data)
{
_playerManager = playerManager;
SessionId = client.SessionId;
UserId = client.UserId;
Name = client.UserName;
_data = data;
PlayerState = new PlayerState
{
SessionId = client.SessionId,
UserId = client.UserId,
};
ConnectedClient = client;
@@ -44,7 +45,7 @@ namespace Robust.Server.Player
private SessionStatus _status = SessionStatus.Connecting;
/// <inheritdoc />
public string Name => SessionId.Username;
public string Name { get; }
/// <inheritdoc />
[ViewVariables]
@@ -73,7 +74,7 @@ namespace Robust.Server.Player
/// <inheritdoc />
[ViewVariables]
public NetSessionId SessionId { get; }
public NetUserId UserId { get; }
private readonly PlayerData _data;
[ViewVariables] public IPlayerData Data => _data;
@@ -161,6 +162,8 @@ namespace Robust.Server.Player
UpdatePlayerState();
}
public LoginType AuthType => ConnectedClient.AuthType;
private void UpdatePlayerState()
{
PlayerState.Status = Status;
@@ -176,7 +179,7 @@ namespace Robust.Server.Player
/// <inheritdoc />
public override string ToString()
{
return SessionId.ToString();
return Name;
}
}
}

View File

@@ -18,6 +18,7 @@
<PackageReference Include="Microsoft.Data.Sqlite" Version="3.1.1" />
<PackageReference Include="prometheus-net" Version="3.5.0" />
<PackageReference Include="Serilog.Sinks.Loki" Version="2.1.0" />
<PackageReference Include="System.IdentityModel.Tokens.Jwt" Version="6.7.1" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\Lidgren.Network\Lidgren.Network.csproj" />

View File

@@ -1,3 +1,4 @@
using System;
using System.IO;
using System.Net;
using System.Net.Http;
@@ -88,7 +89,7 @@ namespace Robust.Server.ServerStatus
JObject? buildInfo;
if (downloadUrlWindows == null)
if (string.IsNullOrEmpty(downloadUrlWindows))
{
buildInfo = null;
}
@@ -113,9 +114,18 @@ namespace Robust.Server.ServerStatus
};
}
var authInfo = new JObject
{
["mode"] = _netManager.Auth.ToString(),
["public_key"] = _netManager.RsaPublicKey != null
? Convert.ToBase64String(_netManager.RsaPublicKey)
: null
};
var jObject = new JObject
{
["connect_address"] = _configurationManager.GetCVar<string>("status.connectaddress"),
["auth"] = authInfo,
["build"] = buildInfo
};
@@ -131,7 +141,6 @@ namespace Robust.Server.ServerStatus
return true;
}
}
}

View File

@@ -19,6 +19,7 @@ using Robust.Server.Interfaces.ServerStatus;
using Robust.Shared.Configuration;
using Robust.Shared.ContentPack;
using Robust.Shared.Interfaces.Configuration;
using Robust.Shared.Interfaces.Network;
using Robust.Shared.IoC;
using Robust.Shared.Log;
@@ -43,6 +44,7 @@ namespace Robust.Server.ServerStatus
private readonly List<StatusHostHandler> _handlers = new List<StatusHostHandler>();
[Dependency] private readonly IConfigurationManager _configurationManager = default!;
[Dependency] private readonly IServerNetManager _netManager = default!;
private KestrelServer _server = default!;
@@ -171,16 +173,16 @@ namespace Robust.Server.ServerStatus
_configurationManager.RegisterCVar("status.enabled", true, CVar.ARCHIVE);
_configurationManager.RegisterCVar("status.bind", "*:1212", CVar.ARCHIVE);
_configurationManager.RegisterCVar<string?>("status.connectaddress", null, CVar.ARCHIVE);
_configurationManager.RegisterCVar("status.connectaddress", "", CVar.ARCHIVE);
_configurationManager.RegisterCVar("build.fork_id", info?.ForkId, CVar.ARCHIVE);
_configurationManager.RegisterCVar("build.version", info?.Version, CVar.ARCHIVE);
_configurationManager.RegisterCVar("build.download_url_windows", info?.Downloads.Windows, CVar.ARCHIVE);
_configurationManager.RegisterCVar("build.download_url_macos", info?.Downloads.MacOS, CVar.ARCHIVE);
_configurationManager.RegisterCVar("build.download_url_linux", info?.Downloads.Linux, CVar.ARCHIVE);
_configurationManager.RegisterCVar("build.hash_windows", info?.Hashes.Windows, CVar.ARCHIVE);
_configurationManager.RegisterCVar("build.hash_macos", info?.Hashes.MacOS, CVar.ARCHIVE);
_configurationManager.RegisterCVar("build.hash_linux", info?.Hashes.Linux, CVar.ARCHIVE);
_configurationManager.RegisterCVar("build.fork_id", info?.ForkId ?? "", CVar.ARCHIVE);
_configurationManager.RegisterCVar("build.version", info?.Version ?? "", CVar.ARCHIVE);
_configurationManager.RegisterCVar("build.download_url_windows", info?.Downloads.Windows ?? "", CVar.ARCHIVE);
_configurationManager.RegisterCVar("build.download_url_macos", info?.Downloads.MacOS ?? "", CVar.ARCHIVE);
_configurationManager.RegisterCVar("build.download_url_linux", info?.Downloads.Linux ?? "", CVar.ARCHIVE);
_configurationManager.RegisterCVar("build.hash_windows", info?.Hashes.Windows ?? "", CVar.ARCHIVE);
_configurationManager.RegisterCVar("build.hash_macos", info?.Hashes.MacOS ?? "", CVar.ARCHIVE);
_configurationManager.RegisterCVar("build.hash_linux", info?.Hashes.Linux ?? "", CVar.ARCHIVE);
}
[JsonObject(ItemRequired = Required.DisallowNull)]

View File

@@ -40,8 +40,8 @@ namespace Robust.Server.ServerStatus
_statusHost.AddHandler(UpdateHandler);
_statusHost.AddHandler(ShutdownHandler);
_configurationManager.RegisterCVar<string?>("watchdog.token", null, onValueChanged: _ => UpdateToken());
_configurationManager.RegisterCVar<string?>("watchdog.key", null, onValueChanged: _ => UpdateToken());
_configurationManager.RegisterCVar("watchdog.token", "", onValueChanged: _ => UpdateToken());
_configurationManager.RegisterCVar("watchdog.key", "", onValueChanged: _ => UpdateToken());
_configurationManager.RegisterCVar("watchdog.baseUrl", "http://localhost:5000");
}
@@ -174,10 +174,12 @@ namespace Robust.Server.ServerStatus
private void UpdateToken()
{
_watchdogToken = _configurationManager.GetCVar<string?>("watchdog.token");
_watchdogKey = _configurationManager.GetCVar<string?>("watchdog.key");
var baseUrl = _configurationManager.GetCVar<string?>("watchdog.baseUrl");
_baseUri = baseUrl != null ? new Uri(baseUrl) : null;
var tok = _configurationManager.GetCVar<string>("watchdog.token");
var key = _configurationManager.GetCVar<string>("watchdog.key");
var baseUrl = _configurationManager.GetCVar<string>("watchdog.baseUrl");
_watchdogToken = string.IsNullOrEmpty(_watchdogToken) ? null : tok;
_watchdogKey = string.IsNullOrEmpty(_watchdogKey) ? null : key;
_baseUri = string.IsNullOrEmpty(baseUrl) ? null : new Uri(baseUrl);
if (_watchdogKey != null && _watchdogToken != null)
{

View File

@@ -47,7 +47,7 @@ namespace Robust.Server.ViewVariables
private void _msgCloseSession(MsgViewVariablesCloseSession message)
{
if (!_sessions.TryGetValue(message.SessionId, out var session)
|| session.PlayerSession != message.MsgChannel.SessionId)
|| session.PlayerUser != message.MsgChannel.UserId)
{
// TODO: logging?
return;
@@ -59,7 +59,7 @@ namespace Robust.Server.ViewVariables
private void _msgModifyRemote(MsgViewVariablesModifyRemote message)
{
if (!_sessions.TryGetValue(message.SessionId, out var session)
|| session.PlayerSession != message.MsgChannel.SessionId)
|| session.PlayerUser != message.MsgChannel.UserId)
{
// TODO: logging?
return;
@@ -77,7 +77,7 @@ namespace Robust.Server.ViewVariables
private void _msgReqData(MsgViewVariablesReqData message)
{
if (!_sessions.TryGetValue(message.SessionId, out var session)
|| session.PlayerSession != message.MsgChannel.SessionId)
|| session.PlayerUser != message.MsgChannel.UserId)
{
// TODO: logging?
return;
@@ -136,7 +136,7 @@ namespace Robust.Server.ViewVariables
}
case ViewVariablesSessionRelativeSelector sessionRelativeSelector:
if (!_sessions.TryGetValue(sessionRelativeSelector.SessionId, out var relSession)
|| relSession.PlayerSession != message.MsgChannel.SessionId)
|| relSession.PlayerUser != message.MsgChannel.UserId)
{
// TODO: logging?
Deny(DenyReason.NoObject);
@@ -190,7 +190,7 @@ namespace Robust.Server.ViewVariables
}
var sessionId = _nextSessionId++;
var session = new ViewVariablesSession(message.MsgChannel.SessionId, theObject, sessionId, this,
var session = new ViewVariablesSession(message.MsgChannel.UserId, theObject, sessionId, this,
_robustSerializer);
_sessions.Add(sessionId, session);
@@ -217,7 +217,7 @@ namespace Robust.Server.ViewVariables
}
_sessions.Remove(sessionId);
if (!sendMsg || !_playerManager.TryGetSessionById(session.PlayerSession, out var player) ||
if (!sendMsg || !_playerManager.TryGetSessionById(session.PlayerUser, out var player) ||
player.Status == SessionStatus.Disconnected)
{
return;

View File

@@ -14,21 +14,21 @@ namespace Robust.Server.ViewVariables
private readonly List<ViewVariablesTrait> _traits = new List<ViewVariablesTrait>();
public IViewVariablesHost Host { get; }
public IRobustSerializer RobustSerializer { get; }
public NetSessionId PlayerSession { get; }
public NetUserId PlayerUser { get; }
public object Object { get; }
public uint SessionId { get; }
public Type ObjectType { get; }
/// <param name="playerSession">The session ID of the player who opened this session.</param>
/// <param name="playerUser">The session ID of the player who opened this session.</param>
/// <param name="o">The object we represent.</param>
/// <param name="sessionId">
/// The session ID for this session. This is what the server and client use to talk about this session.
/// </param>
/// <param name="host">The view variables host owning this session.</param>
public ViewVariablesSession(NetSessionId playerSession, object o, uint sessionId, IViewVariablesHost host,
public ViewVariablesSession(NetUserId playerUser, object o, uint sessionId, IViewVariablesHost host,
IRobustSerializer robustSerializer)
{
PlayerSession = playerSession;
PlayerUser = playerUser;
Object = o;
SessionId = sessionId;
ObjectType = o.GetType();

View File

@@ -11,7 +11,8 @@ namespace Robust.Shared
throw new InvalidOperationException("This class must not be instantiated");
}
public static readonly CVarDef<int> NetPort = CVarDef.Create("net.port", 1212, CVar.ARCHIVE);
public static readonly CVarDef<int> NetPort =
CVarDef.Create("net.port", 1212, CVar.ARCHIVE);
public static readonly CVarDef<int> NetSendBufferSize =
CVarDef.Create("net.sendbuffersize", 131071, CVar.ARCHIVE);
@@ -44,6 +45,24 @@ namespace Robust.Shared
public static readonly CVarDef<int> GameMaxPlayers =
CVarDef.Create("game.maxplayers", 32, CVar.ARCHIVE | CVar.SERVERONLY);
public static readonly CVarDef<int> AuthMode =
CVarDef.Create("auth.mode", (int) Network.AuthMode.Optional, CVar.SERVERONLY);
public static readonly CVarDef<bool> AuthAllowLocal =
CVarDef.Create("auth.allowlocal", true, CVar.SERVERONLY);
public static readonly CVarDef<string> AuthServerPubKey =
CVarDef.Create("auth.serverpubkey", "", CVar.SECURE | CVar.CLIENTONLY);
public static readonly CVarDef<string> AuthToken =
CVarDef.Create("auth.token", "", CVar.SECURE | CVar.CLIENTONLY);
public static readonly CVarDef<string> AuthUserId =
CVarDef.Create("auth.userid", "", CVar.SECURE | CVar.CLIENTONLY);
public static readonly CVarDef<string> AuthServer =
CVarDef.Create("auth.server", "http://localhost:5000/", CVar.SECURE);
#if DEBUG
public static readonly CVarDef<float> NetFakeLoss = CVarDef.Create("net.fakeloss", 0f, CVar.CHEAT);
public static readonly CVarDef<float> NetFakeLagMin = CVarDef.Create("net.fakelagmin", 0f, CVar.CHEAT);

View File

@@ -58,5 +58,10 @@ namespace Robust.Shared.Configuration
/// This is intended to aid shared code.
/// </remarks>
CLIENTONLY = 128,
/// <summary>
/// This var has to kept secure and may not be accessed by content.
/// </summary>
SECURE = 256
}
}

View File

@@ -12,7 +12,7 @@ namespace Robust.Shared.Configuration
/// <summary>
/// Stores and manages global configuration variables.
/// </summary>
public class ConfigurationManager : IConfigurationManager
public class ConfigurationManager : IConfigurationManagerInternal
{
private const char TABLE_DELIMITER = '.';
private readonly Dictionary<string, ConfigVar> _configVars = new Dictionary<string, ConfigVar>();
@@ -276,20 +276,22 @@ namespace Robust.Shared.Configuration
/// <inheritdoc />
public bool IsCVarRegistered(string name)
{
return _configVars.TryGetValue(name, out var cVar) && cVar.Registered;
return _configVars.TryGetValue(name, out var cVar) && cVar.Registered && (cVar.Flags & CVar.SECURE) == 0;
}
/// <inheritdoc />
public IEnumerable<string> GetRegisteredCVars()
{
return _configVars.Keys;
return _configVars
.Where(p => (p.Value.Flags & CVar.SECURE) == 0)
.Select(p => p.Key);
}
/// <inheritdoc />
public void SetCVar(string name, object value)
{
//TODO: Make flags work, required non-derpy net system.
if (_configVars.TryGetValue(name, out var cVar) && cVar.Registered)
if (_configVars.TryGetValue(name, out var cVar) && cVar.Registered && (cVar.Flags & CVar.SECURE) == 0)
{
if (!Equals(cVar.Value, value))
{
@@ -307,6 +309,15 @@ namespace Robust.Shared.Configuration
/// <inheritdoc />
public T GetCVar<T>(string name)
{
if (_configVars.TryGetValue(name, out var cVar) && cVar.Registered && (cVar.Flags & CVar.SECURE) == 0)
//TODO: Make flags work, required non-derpy net system.
return (T)(cVar.OverrideValueParsed ?? cVar.Value ?? cVar.DefaultValue)!;
throw new InvalidConfigurationException($"Trying to get unregistered variable '{name}'");
}
public T GetSecureCVar<T>(string name)
{
if (_configVars.TryGetValue(name, out var cVar) && cVar.Registered)
//TODO: Make flags work, required non-derpy net system.
@@ -322,7 +333,10 @@ namespace Robust.Shared.Configuration
public Type GetCVarType(string name)
{
var cVar = _configVars[name];
if (!_configVars.TryGetValue(name, out var cVar) || !cVar.Registered || (cVar.Flags & CVar.SECURE) != 0)
{
throw new InvalidConfigurationException($"Trying to get type of unregistered variable '{name}'");
}
// If it's null it's a string, since the rest is primitives which aren't null.
return cVar.Value?.GetType() ?? typeof(string);

View File

@@ -12,7 +12,7 @@ namespace Robust.Shared.GameStates
[Serializable, NetSerializable]
public sealed class PlayerState
{
public NetSessionId SessionId { get; set; }
public NetUserId UserId { get; set; }
public string Name { get; set; }
public SessionStatus Status { get; set; }

View File

@@ -0,0 +1,7 @@
namespace Robust.Shared.Interfaces.Configuration
{
internal interface IConfigurationManagerInternal : IConfigurationManager
{
T GetSecureCVar<T>(string name);
}
}

View File

@@ -28,7 +28,11 @@ namespace Robust.Shared.Interfaces.Network
/// On the server, this is the session ID for this client.
/// On the client, this is the session ID for the client.
/// </summary>
NetSessionId SessionId { get; }
NetUserId UserId { get; }
string UserName { get; }
LoginType AuthType { get; }
/// <summary>
/// Average round trip time in milliseconds between the remote peer and us.

View File

@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Robust.Shared.Network;
namespace Robust.Shared.Interfaces.Network
@@ -104,7 +105,7 @@ namespace Robust.Shared.Interfaces.Network
/// <summary>
/// An incoming connection is being received.
/// </summary>
event EventHandler<NetConnectingArgs> Connecting;
event Func<NetConnectingArgs, Task> Connecting;
/// <summary>
/// A client has just connected to the server.

View File

@@ -1,4 +1,8 @@
using System;
using System.Threading.Tasks;
using Robust.Shared.Network;
namespace Robust.Shared.Interfaces.Network
{
/// <summary>
@@ -6,6 +10,13 @@ namespace Robust.Shared.Interfaces.Network
/// </summary>
public interface IServerNetManager : INetManager
{
public delegate Task<NetApproval> NetApprovalDelegate(NetApprovalEventArgs eventArgs);
byte[]? RsaPublicKey { get; }
AuthMode Auth { get; }
Func<string, Task<NetUserId?>>? AssignUserIdCallback { get; set; }
NetApprovalDelegate? HandleApprovalCallback { get; set; }
/// <summary>
/// Disconnects this channel from the remote peer.
/// </summary>

View File

@@ -0,0 +1,30 @@
using System;
namespace Robust.Shared.Interfaces.Network
{
public struct NetApproval
{
public bool IsApproved => _denyReason == null;
public string DenyReason => _denyReason != null
? _denyReason!
: throw new InvalidOperationException("This was not a denial.");
private readonly string? _denyReason;
private NetApproval(string? denyReason)
{
_denyReason = denyReason;
}
public static NetApproval Deny(string reason)
{
return new NetApproval(reason);
}
public static NetApproval Allow()
{
return new NetApproval(null);
}
}
}

View File

@@ -0,0 +1,15 @@
using System;
using Lidgren.Network;
namespace Robust.Shared.Interfaces.Network
{
public sealed class NetApprovalEventArgs : EventArgs
{
public NetConnection Connection { get; }
public NetApprovalEventArgs(NetConnection connection)
{
Connection = connection;
}
}
}

View File

@@ -0,0 +1,9 @@
namespace Robust.Shared.Network
{
public enum AuthMode
{
Optional = 0,
Required = 1,
Disabled = 2
}
}

View File

@@ -0,0 +1,28 @@
namespace Robust.Shared.Network
{
public enum LoginType : byte
{
/// <summary>
/// This player is not logged in and as soon as they disconnect their data will be gone, probably.
/// </summary>
Guest = 0,
/// <summary>
/// This player is properly logged in with an auth account.
/// </summary>
LoggedIn = 1,
/// <summary>
/// This player is not logged in but their username does have a static user ID assigned.
/// </summary>
GuestAssigned = 2
}
public static class LoginTypeExt
{
public static bool HasStaticUserId(this LoginType type)
{
return type == LoginType.LoggedIn || type == LoginType.GuestAssigned;
}
}
}

View File

@@ -0,0 +1,32 @@
using Lidgren.Network;
#nullable disable
namespace Robust.Shared.Network.Messages
{
internal sealed class MsgEncryptionRequest : NetMessage
{
public MsgEncryptionRequest() : base("", MsgGroups.Core)
{
}
public byte[] VerifyToken;
public byte[] PublicKey;
public override void ReadFromBuffer(NetIncomingMessage buffer)
{
var tokenLength = buffer.ReadVariableInt32();
VerifyToken = buffer.ReadBytes(tokenLength);
var keyLength = buffer.ReadVariableInt32();
PublicKey = buffer.ReadBytes(keyLength);
}
public override void WriteToBuffer(NetOutgoingMessage buffer)
{
buffer.WriteVariableInt32(VerifyToken.Length);
buffer.Write(VerifyToken);
buffer.WriteVariableInt32(PublicKey.Length);
buffer.Write(PublicKey);
}
}
}

View File

@@ -0,0 +1,36 @@
using System;
using Lidgren.Network;
#nullable disable
namespace Robust.Shared.Network.Messages
{
internal sealed class MsgEncryptionResponse : NetMessage
{
public MsgEncryptionResponse() : base("", MsgGroups.Core)
{
}
public Guid UserId;
public byte[] SharedSecret;
public byte[] VerifyToken;
public override void ReadFromBuffer(NetIncomingMessage buffer)
{
UserId = buffer.ReadGuid();
var keyLength = buffer.ReadVariableInt32();
SharedSecret = buffer.ReadBytes(keyLength);
var tokenLength = buffer.ReadVariableInt32();
VerifyToken = buffer.ReadBytes(tokenLength);
}
public override void WriteToBuffer(NetOutgoingMessage buffer)
{
buffer.Write(UserId);
buffer.WriteVariableInt32(SharedSecret.Length);
buffer.Write(SharedSecret);
buffer.WriteVariableInt32(VerifyToken.Length);
buffer.Write(VerifyToken);
}
}
}

View File

@@ -0,0 +1,34 @@
using Lidgren.Network;
#nullable disable
namespace Robust.Shared.Network.Messages
{
internal sealed class MsgLoginStart : NetMessage
{
// **NOTE**: This is a special message sent during the client<->server handshake.
// It doesn't actually get sent normally and as such doesn't have the "normal" boilerplate.
// It's basically just a sane way to encapsulate the message write/read logic.
public MsgLoginStart() : base("", MsgGroups.Core)
{
}
public string UserName;
public bool CanAuth;
public bool NeedPubKey;
public override void ReadFromBuffer(NetIncomingMessage buffer)
{
UserName = buffer.ReadString();
CanAuth = buffer.ReadBoolean();
NeedPubKey = buffer.ReadBoolean();
}
public override void WriteToBuffer(NetOutgoingMessage buffer)
{
buffer.Write(UserName);
buffer.Write(CanAuth);
buffer.Write(NeedPubKey);
}
}
}

View File

@@ -0,0 +1,34 @@
using System;
using Lidgren.Network;
#nullable disable
namespace Robust.Shared.Network.Messages
{
internal sealed class MsgLoginSuccess : NetMessage
{
// Same deal as MsgLogin, helper for NetManager only.
public MsgLoginSuccess() : base("", MsgGroups.Core)
{
}
public string UserName;
public Guid UserId;
public LoginType Type;
public override void ReadFromBuffer(NetIncomingMessage buffer)
{
UserName = buffer.ReadString();
UserId = buffer.ReadGuid();
Type = (LoginType) buffer.ReadByte();
}
public override void WriteToBuffer(NetOutgoingMessage buffer)
{
buffer.Write(UserName);
buffer.Write(UserId);
buffer.Write((byte) Type);
}
}
}

View File

@@ -28,7 +28,7 @@ namespace Robust.Shared.Network.Messages
{
var plyNfo = new PlayerState
{
SessionId = new NetSessionId(buffer.ReadString()),
UserId = new NetUserId(buffer.ReadGuid()),
Name = buffer.ReadString(),
Status = (SessionStatus)buffer.ReadByte(),
Ping = buffer.ReadInt16()
@@ -43,7 +43,7 @@ namespace Robust.Shared.Network.Messages
foreach (var ply in Plyrs)
{
buffer.Write(ply.SessionId.Username);
buffer.Write(ply.UserId.UserId);
buffer.Write(ply.Name);
buffer.Write((byte) ply.Status);
buffer.Write(ply.Ping);

View File

@@ -16,14 +16,12 @@ namespace Robust.Shared.Network.Messages
public string ServerName { get; set; }
public int ServerMaxPlayers { get; set; }
public byte TickRate { get; set; }
public NetSessionId PlayerSessionId { get; set; }
public override void ReadFromBuffer(NetIncomingMessage buffer)
{
ServerName = buffer.ReadString();
ServerMaxPlayers = buffer.ReadInt32();
TickRate = buffer.ReadByte();
PlayerSessionId = new NetSessionId(buffer.ReadString());
}
public override void WriteToBuffer(NetOutgoingMessage buffer)
@@ -31,7 +29,6 @@ namespace Robust.Shared.Network.Messages
buffer.Write(ServerName);
buffer.Write(ServerMaxPlayers);
buffer.Write(TickRate);
buffer.Write(PlayerSessionId.Username);
}
}
}

View File

@@ -2,7 +2,6 @@
using System.Net;
using Lidgren.Network;
using Robust.Shared.Interfaces.Network;
using Robust.Shared.Utility;
namespace Robust.Shared.Network
{
@@ -20,6 +19,9 @@ namespace Robust.Shared.Network
/// <inheritdoc />
public INetManager NetPeer => _manager;
public string UserName { get; }
public LoginType AuthType { get; }
/// <inheritdoc />
public short Ping => (short) Math.Round(_connection.AverageRoundtripTime * 1000);
@@ -34,18 +36,23 @@ namespace Robust.Shared.Network
/// </summary>
public NetConnection Connection => _connection;
public NetSessionId SessionId { get; }
public NetUserId UserId { get; }
// Only used on server, contains the encryption to use for this channel.
public NetEncryption? Encryption { get; set; }
/// <summary>
/// Creates a new instance of a NetChannel.
/// </summary>
/// <param name="manager">The server this channel belongs to.</param>
/// <param name="connection">The raw NetConnection to the remote peer.</param>
internal NetChannel(NetManager manager, NetConnection connection, NetSessionId sessionId)
internal NetChannel(NetManager manager, NetConnection connection, NetUserId userId, string userName, LoginType loginType)
{
_manager = manager;
_connection = connection;
SessionId = sessionId;
UserId = userId;
UserName = userName;
AuthType = loginType;
}
/// <inheritdoc />

View File

@@ -29,26 +29,34 @@ namespace Robust.Shared.Network
/// </summary>
public class NetConnectingArgs : EventArgs
{
/// <summary>
/// If this is set to true, deny the incoming connection.
/// </summary>
public bool Deny { get; set; } = false;
public bool IsDenied => DenyReason != null;
public string? DenyReason { get; private set; }
/// <summary>
/// The IP of the incoming connection.
/// </summary>
public readonly NetSessionId SessionId;
public readonly NetUserId UserId;
public readonly IPEndPoint IP;
public readonly string UserName;
public readonly LoginType AuthType;
public void Deny(string reason)
{
DenyReason = reason;
}
/// <summary>
/// Constructs a new instance.
/// </summary>
/// <param name="sessionId">The session ID of the incoming connection.</param>
public NetConnectingArgs(NetSessionId sessionId, IPEndPoint ip)
/// <param name="userId">The session ID of the incoming connection.</param>
public NetConnectingArgs(NetUserId userId, IPEndPoint ip, string userName, LoginType authType)
{
SessionId = sessionId;
UserId = userId;
IP = ip;
UserName = userName;
AuthType = authType;
}
}

View File

@@ -2,12 +2,18 @@ using System;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Net.Mime;
using System.Net.Sockets;
using System.Security.Cryptography;
using System.Threading;
using System.Threading.Tasks;
using Lidgren.Network;
using Newtonsoft.Json;
using Robust.Shared.Interfaces.Network;
using Robust.Shared.Log;
using Robust.Shared.Network.Messages;
using Robust.Shared.Utility;
namespace Robust.Shared.Network
@@ -62,20 +68,176 @@ namespace Robust.Shared.Network
Logger.DebugS("net", "Attempting to connect to {0} port {1}", host, port);
// Get list of potential IP addresses for the domain.
var endPoints = await ResolveDnsAsync(host);
if (mainCancelToken.IsCancellationRequested)
var resolveResult = await CCResolveHost(host, mainCancelToken);
if (resolveResult == null)
{
ClientConnectState = ClientConnectionState.NotConnecting;
return;
}
var (first, second) = resolveResult.Value;
ClientConnectState = ClientConnectionState.EstablishingConnection;
Logger.DebugS("net", "First attempt IP address is {0}, second attempt {1}", first, second);
var result = await CCHappyEyeballs(port, first, second, mainCancelToken);
if (result == null)
{
ClientConnectState = ClientConnectionState.NotConnecting;
return;
}
var (winningPeer, winningConnection) = result.Value;
ClientConnectState = ClientConnectionState.Handshake;
// We're connected start handshaking.
try
{
await CCDoHandshake(winningPeer, winningConnection, userNameRequest, mainCancelToken);
}
catch (TaskCanceledException)
{
winningPeer.Peer.Shutdown("Cancelled");
_toCleanNetPeers.Add(winningPeer.Peer);
ClientConnectState = ClientConnectionState.NotConnecting;
return;
}
catch (Exception e)
{
OnConnectFailed(e.Message);
Logger.ErrorS("net", "Exception during handshake: {0}", e);
winningPeer.Peer.Shutdown("Something happened.");
_toCleanNetPeers.Add(winningPeer.Peer);
ClientConnectState = ClientConnectionState.NotConnecting;
return;
}
ClientConnectState = ClientConnectionState.Connected;
Logger.DebugS("net", "Handshake completed, connection established.");
}
private async Task CCDoHandshake(NetPeerData peer, NetConnection connection, string userNameRequest,
CancellationToken cancel)
{
var authToken = _config.GetSecureCVar<string>("auth.token");
var pubKey = _config.GetSecureCVar<string>("auth.serverpubkey");
var authServer = _config.GetSecureCVar<string>("auth.server");
var userIdStr = _config.GetSecureCVar<string>("auth.userid");
var hasPubKey = !string.IsNullOrEmpty(pubKey);
var authenticate = !string.IsNullOrEmpty(authToken);
var msgLogin = new MsgLoginStart
{
UserName = userNameRequest,
CanAuth = authenticate,
NeedPubKey = !hasPubKey
};
var outLoginMsg = peer.Peer.CreateMessage();
msgLogin.WriteToBuffer(outLoginMsg);
peer.Peer.SendMessage(outLoginMsg, connection, NetDeliveryMethod.ReliableOrdered);
NetEncryption? encryption = null;
var response = await AwaitData(connection, cancel);
var loginSuccess = response.ReadBoolean();
response.ReadPadBits();
if (!loginSuccess)
{
// Need to authenticate, packet is MsgEncryptionRequest
var encRequest = new MsgEncryptionRequest();
encRequest.ReadFromBuffer(response);
var sharedSecret = new byte[AesKeyLength];
RandomNumberGenerator.Fill(sharedSecret);
encryption = new NetAESEncryption(peer.Peer, sharedSecret, 0, sharedSecret.Length);
byte[] keyBytes;
if (hasPubKey)
{
// public key provided by launcher.
keyBytes = Convert.FromBase64String(pubKey);
}
else
{
// public key is gotten from handshake.
keyBytes = encRequest.PublicKey;
}
var rsaKey = RSA.Create();
rsaKey.ImportRSAPublicKey(keyBytes, out _);
var encryptedSecret = rsaKey.Encrypt(sharedSecret, RSAEncryptionPadding.OaepSHA256);
var encryptedVerifyToken = rsaKey.Encrypt(encRequest.VerifyToken, RSAEncryptionPadding.OaepSHA256);
var authHashBytes = MakeAuthHash(sharedSecret, keyBytes);
var authHash = Convert.ToBase64String(authHashBytes);
var joinReq = new JoinRequest {Hash = authHash};
var httpClient = new HttpClient();
httpClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("SS14Auth", authToken);
var joinJson = JsonConvert.SerializeObject(joinReq);
var joinResp = await httpClient.PostAsync(authServer + "api/session/join",
new StringContent(joinJson, EncodingHelpers.UTF8, MediaTypeNames.Application.Json), cancel);
joinResp.EnsureSuccessStatusCode();
var encryptionResponse = new MsgEncryptionResponse
{
SharedSecret = encryptedSecret,
VerifyToken = encryptedVerifyToken,
UserId = new Guid(userIdStr)
};
var outEncRespMsg = peer.Peer.CreateMessage();
encryptionResponse.WriteToBuffer(outEncRespMsg);
peer.Peer.SendMessage(outEncRespMsg, connection, NetDeliveryMethod.ReliableOrdered);
// Expect login success here.
response = await AwaitData(connection, cancel);
encryption.Decrypt(response);
}
var msgSuc = new MsgLoginSuccess();
msgSuc.ReadFromBuffer(response);
var channel = new NetChannel(this, connection, new NetUserId(msgSuc.UserId), msgSuc.UserName, msgSuc.Type);
_channels.Add(connection, channel);
peer.AddChannel(channel);
_clientEncryption = encryption;
}
private static byte[] MakeAuthHash(byte[] sharedSecret, byte[] pkBytes)
{
Logger.DebugS("auth", "shared: {0}, pk: {1}", Convert.ToBase64String(sharedSecret), Convert.ToBase64String(pkBytes));
var incHash = IncrementalHash.CreateHash(HashAlgorithmName.SHA256);
incHash.AppendData(sharedSecret);
incHash.AppendData(pkBytes);
return incHash.GetHashAndReset();
}
private async Task<(IPAddress first, IPAddress? second)?>
CCResolveHost(string host, CancellationToken mainCancelToken)
{
// Get list of potential IP addresses for the domain.
var endPoints = await ResolveDnsAsync(host);
if (mainCancelToken.IsCancellationRequested)
{
return null;
}
if (endPoints == null)
{
OnConnectFailed($"Unable to resolve domain '{host}'");
ClientConnectState = ClientConnectionState.NotConnecting;
return;
return null;
}
// Try to get an IPv6 and IPv4 address.
@@ -85,12 +247,9 @@ namespace Robust.Shared.Network
if (ipv4 == null && ipv6 == null)
{
OnConnectFailed($"Domain '{host}' has no associated IP addresses");
ClientConnectState = ClientConnectionState.NotConnecting;
return;
return null;
}
ClientConnectState = ClientConnectionState.EstablishingConnection;
IPAddress first;
IPAddress? second = null;
if (ipv6 != null)
@@ -104,8 +263,12 @@ namespace Robust.Shared.Network
first = ipv4!;
}
Logger.DebugS("net", "First attempt IP address is {0}, second attempt {1}", first, second);
return (first, second);
}
private async Task<(NetPeerData winningPeer, NetConnection winningConnection)?>
CCHappyEyeballs(int port, IPAddress first, IPAddress? second, CancellationToken mainCancelToken)
{
NetPeerData CreatePeerForIp(IPAddress address)
{
var config = _getBaseNetPeerConfig();
@@ -249,8 +412,7 @@ namespace Robust.Shared.Network
_toCleanNetPeers.Add(secondPeer.Peer);
}
ClientConnectState = ClientConnectionState.NotConnecting;
return;
return null;
}
// winningPeer can still be failed at this point.
@@ -260,48 +422,10 @@ namespace Robust.Shared.Network
winningPeer!.Peer.Shutdown("You failed");
_toCleanNetPeers.Add(winningPeer.Peer);
OnConnectFailed((secondReason ?? firstReason)!);
ClientConnectState = ClientConnectionState.NotConnecting;
return;
return null;
}
ClientConnectState = ClientConnectionState.Handshake;
// We're connected start handshaking.
var userNameRequestMsg = winningPeer!.Peer.CreateMessage(userNameRequest);
winningPeer.Peer.SendMessage(userNameRequestMsg, winningConnection, NetDeliveryMethod.ReliableOrdered);
try
{
// Await response.
var response = await AwaitData(winningConnection, mainCancelToken);
var receivedUsername = response.ReadString();
var channel = new NetChannel(this, winningConnection, new NetSessionId(receivedUsername));
_channels.Add(winningConnection, channel);
winningPeer.AddChannel(channel);
var confirmConnectionMsg = winningPeer.Peer.CreateMessage("ok");
winningPeer.Peer.SendMessage(confirmConnectionMsg, winningConnection, NetDeliveryMethod.ReliableOrdered);
}
catch (TaskCanceledException)
{
winningPeer.Peer.Shutdown("Cancelled");
_toCleanNetPeers.Add(secondPeer!.Peer);
ClientConnectState = ClientConnectionState.NotConnecting;
return;
}
catch (Exception e)
{
OnConnectFailed(e.Message);
Logger.ErrorS("net", "Exception during handshake: {0}", e);
winningPeer.Peer.Shutdown("Something happened.");
_toCleanNetPeers.Add(secondPeer!.Peer);
ClientConnectState = ClientConnectionState.NotConnecting;
return;
}
ClientConnectState = ClientConnectionState.Connected;
Logger.DebugS("net", "Handshake completed, connection established.");
return (winningPeer!, winningConnection);
}
private Task<string> AwaitStatusChange(NetConnection connection, CancellationToken cancellationToken = default)
@@ -334,6 +458,9 @@ namespace Robust.Shared.Network
throw new InvalidOperationException("Cannot await data twice.");
}
DebugTools.Assert(!_channels.ContainsKey(connection),
"AwaitData cannot be used once a proper channel for the connection has been constructed, as it does not support encryption.");
var tcs = new TaskCompletionSource<NetIncomingMessage>();
CancellationTokenRegistration reg = default;
if (cancellationToken != default)
@@ -380,5 +507,9 @@ namespace Robust.Shared.Network
}
}
private sealed class JoinRequest
{
public string Hash = default!;
}
}
}

View File

@@ -0,0 +1,44 @@
This file serves as documentation for network stuff
# Authentication Handshake
The client and server connect via Lidgren.Network.
This will be immediately obvious to you if you spent any time reading the code.
The game server can either require authentication, optionally allow authentication, or disable authentication entirely.
The packet exchange looks like this:
1. C->S `MsgLoginStart`
1. If client requests auth and server allows/requires auth:
2. S->C `MsgEncryptionRequest`
3. (client auth)
4. C->S `MsgEncryptionResponse`
5. (server auth, both enable encryption)
2. S->C `MsgLoginSuccess`
<small><small>Yes this is literally taken from [Minecraft's authentication protocol](https://wiki.vg/Protocol_Encryption) </small></small>
Note that the S->C packet AFTER `MsgLoginStart` is preceded with a bool (+pad) to indicate whether auth is being done or not. None of the net messages mentioned here are sent as "regular" net messages. They are used as containers for the write/read logic only. Barring the exception mentioned just now, they are read/written directory from the Lidgren data message instead of with a preceding string table ID.
A more detailed overview is here:
First the client sends `MsgLoginStart`. This contains the client's username, whether it wants to authenticate, and whether it needs the server's public RSA key sent (when authenticating && it doesn't have it yet from the launcher).
The server can then choose to do block the client, let the client authenticate, or let the client in as guest. If it lets the client in as guest it skips straight to sending `MsgLoginSuccess` (see below). Otherwise it will send an `MsgEncryptionRequest` to the client to initiate authentication.
`MsgEncryptionRequest` contains a random verify token sent by the server, as well as the server's public RSA key (if requested).
When the client receives `MsgEncryptionRequest`, it will generate a 32-byte random secret. It will then generate an SHA-256 hash of this secret and the server's public key. This hash is POSTed to `api/session/join` (along with login token in `Authorization` header) on the auth server. The shared secret and verify token are separately encrypted with the server's RSA key, then sent along with the client's account GUID to the server in `MsgEncryptionResponse`.
The server will then decrypt the verify token and shared secret with its private RSA key. If the verify token does not match then drop the client (to check if the client is using the correct key). Then the server will generate the same hash as mentioned earlier and GET it to `api/session/hasJoined?hash=<hash>&userId=<userId>` to check if the user did indeed authenticate correctly. And also gets the user's username and GUID again because why not.
From this point on, if authenticating, all messages sent between client and server will be AES encrypted with the shared secret generated earlier.
Then the server shall reply with `MsgLoginSuccess` with the assigned username/userID if login is successful.
I think that was everything.
Oh yeah, the server generates a new 2048-bit RSA key every startup and exposes it via its status API on `/info`.
This is a rough outline. If you want complete gritty details just check the damn code.

View File

@@ -0,0 +1,278 @@
using System;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Security.Cryptography;
using System.Threading.Tasks;
using Lidgren.Network;
using Newtonsoft.Json;
using Robust.Shared.Interfaces.Network;
using Robust.Shared.Log;
using Robust.Shared.Network.Messages;
using Robust.Shared.Utility;
using UsernameHelpers = Robust.Shared.AuthLib.UsernameHelpers;
namespace Robust.Shared.Network
{
partial class NetManager
{
private const int RsaKeySize = 2048;
private const int VerifyTokenSize = 4; // Literally just what MC does idk.
private RSA? _authRsaPrivateKey;
public byte[]? RsaPublicKey { get; private set; }
public AuthMode Auth { get; private set; }
public Func<string, Task<NetUserId?>>? AssignUserIdCallback { get; set; }
public IServerNetManager.NetApprovalDelegate? HandleApprovalCallback { get; set; }
private void SAGenerateRsaKeys()
{
_authRsaPrivateKey = RSA.Create(RsaKeySize);
RsaPublicKey = _authRsaPrivateKey.ExportRSAPublicKey();
/*
Logger.DebugS("auth", "Private RSA key is {0}",
Convert.ToBase64String(_authRsaPrivateKey.ExportRSAPrivateKey()));
*/
Logger.DebugS("auth", "Public RSA key is {0}", Convert.ToBase64String(RsaPublicKey));
}
private async void HandleHandshake(NetPeerData peer, NetConnection connection)
{
try
{
var incPacket = await AwaitData(connection);
var msgLogin = new MsgLoginStart();
msgLogin.ReadFromBuffer(incPacket);
var ip = connection.RemoteEndPoint.Address;
var isLocal = IPAddress.IsLoopback(ip) && _config.GetCVar<bool>("auth.allowlocal");
var canAuth = msgLogin.CanAuth;
var needPk = msgLogin.NeedPubKey;
var authServer = _config.GetSecureCVar<string>("auth.server");
if (Auth == AuthMode.Required && !isLocal)
{
if (!canAuth)
{
connection.Disconnect("Connecting to this server requires authentication");
return;
}
}
NetEncryption? encryption = null;
NetUserId userId;
string userName;
LoginType type;
var padSuccessMessage = true;
if (canAuth && Auth != AuthMode.Disabled)
{
var verifyToken = new byte[4];
RandomNumberGenerator.Fill(verifyToken);
var msgEncReq = new MsgEncryptionRequest
{
PublicKey = needPk ? RsaPublicKey : Array.Empty<byte>(),
VerifyToken = verifyToken
};
var outMsgEncReq = peer.Peer.CreateMessage();
outMsgEncReq.Write(false);
outMsgEncReq.WritePadBits();
msgEncReq.WriteToBuffer(outMsgEncReq);
peer.Peer.SendMessage(outMsgEncReq, connection, NetDeliveryMethod.ReliableOrdered);
incPacket = await AwaitData(connection);
var msgEncResponse = new MsgEncryptionResponse();
msgEncResponse.ReadFromBuffer(incPacket);
var verifyTokenCheck = _authRsaPrivateKey!.Decrypt(
msgEncResponse.VerifyToken,
RSAEncryptionPadding.OaepSHA256);
var sharedSecret = _authRsaPrivateKey!.Decrypt(
msgEncResponse.SharedSecret,
RSAEncryptionPadding.OaepSHA256);
if (!verifyToken.SequenceEqual(verifyTokenCheck))
{
connection.Disconnect("Verify token is invalid");
return;
}
encryption = new NetAESEncryption(peer.Peer, sharedSecret, 0, sharedSecret.Length);
var authHashBytes = MakeAuthHash(sharedSecret, RsaPublicKey!);
var authHash = Base64Helpers.ConvertToBase64Url(authHashBytes);
var client = new HttpClient();
var url = $"{authServer}api/session/hasJoined?hash={authHash}&userId={msgEncResponse.UserId}";
var joinedResp = await client.GetAsync(url);
joinedResp.EnsureSuccessStatusCode();
var joinedRespJson = JsonConvert.DeserializeObject<HasJoinedResponse>(
await joinedResp.Content.ReadAsStringAsync());
if (!joinedRespJson.IsValid)
{
connection.Disconnect("Failed to validate login");
return;
}
userId = new NetUserId(joinedRespJson.UserData!.UserId);
userName = joinedRespJson.UserData.UserName;
padSuccessMessage = false;
type = LoginType.LoggedIn;
}
else
{
var reqUserName = msgLogin.UserName;
if (!UsernameHelpers.IsNameValid(reqUserName, out var reason))
{
connection.Disconnect($"Username is invalid ({reason.ToText()}).");
return;
}
// If auth is set to "optional" we need to avoid conflicts between real accounts and guests,
// so we explicitly prefix guests.
var origName = Auth == AuthMode.Disabled
? reqUserName
: (isLocal ? $"localhost@{reqUserName}" : $"guest@{reqUserName}");
var name = origName;
var iterations = 1;
while (_assignedUsernames.ContainsKey(name))
{
// This is shit but I don't care.
name = $"{origName}_{++iterations}";
}
userName = name;
(userId, type) = await AssignUserIdAsync(name);
}
var endPoint = connection.RemoteEndPoint;
var connect = await OnConnecting(endPoint, userId, userName, type);
if (connect.IsDenied)
{
connection.Disconnect($"Connection denied: {connect.DenyReason}");
return;
}
// Well they're in. Kick a connected client with the same GUID if we have to.
if (_assignedUserIds.TryGetValue(userId, out var existing))
{
existing.Disconnect("Another connection has been made with your account.");
// Have to wait until they're properly off the server to avoid any collisions.
await AwaitDisconnectAsync(existing);
}
var msg = peer.Peer.CreateMessage();
var msgResp = new MsgLoginSuccess
{
UserId = userId.UserId,
UserName = userName,
Type = type
};
if (padSuccessMessage)
{
msg.Write(true);
msg.WritePadBits();
}
msgResp.WriteToBuffer(msg);
encryption?.Encrypt(msg);
peer.Peer.SendMessage(msg, connection, NetDeliveryMethod.ReliableOrdered);
Logger.InfoS("net",
"Approved {ConnectionEndpoint} with username {Username} user ID {userId} into the server",
connection.RemoteEndPoint, userName, userId);
// Handshake complete!
HandleInitialHandshakeComplete(peer, connection, userId, userName, encryption, type);
}
catch (ClientDisconnectedException)
{
Logger.InfoS("net",
$"Peer {NetUtility.ToHexString(connection.RemoteUniqueIdentifier)} disconnected while handshake was in-progress.");
}
catch (Exception e)
{
connection.Disconnect("Unknown server error occured during handshake.");
Logger.ErrorS("net", "Exception during handshake with peer {0}:\n{1}",
NetUtility.ToHexString(connection.RemoteUniqueIdentifier), e);
}
}
private async Task<(NetUserId, LoginType)> AssignUserIdAsync(string username)
{
if (AssignUserIdCallback == null)
{
goto unassigned;
}
var assigned = await AssignUserIdCallback(username);
if (assigned != null)
{
return (assigned.Value, LoginType.GuestAssigned);
}
unassigned:
// Just generate a random new GUID.
var uid = new NetUserId(Guid.NewGuid());
return (uid, LoginType.Guest);
}
private Task AwaitDisconnectAsync(NetConnection connection)
{
var tcs = new TaskCompletionSource<object?>();
_awaitingDisconnect.Add(connection, tcs);
return tcs.Task;
}
private async void HandleApproval(NetIncomingMessage message)
{
// TODO: Maybe preemptively refuse connections here in some cases?
if (message.SenderConnection.Status != NetConnectionStatus.RespondedAwaitingApproval)
{
// This can happen if the approval message comes in after the state changes to disconnected.
// In that case just ignore it.
return;
}
if (HandleApprovalCallback != null)
{
var approval = await HandleApprovalCallback(new NetApprovalEventArgs(message.SenderConnection));
if (!approval.IsApproved)
{
message.SenderConnection.Deny(approval.DenyReason);
return;
}
}
message.SenderConnection.Approve();
}
private sealed class HasJoinedResponse
{
#pragma warning disable 649
public bool IsValid;
public HasJoinedUserData? UserData;
public sealed class HasJoinedUserData
{
public string UserName = default!;
public Guid UserId = default!;
}
#pragma warning restore 649
}
}
}

View File

@@ -11,18 +11,16 @@ using System.Threading;
using System.Threading.Tasks;
using Lidgren.Network;
using Prometheus;
using Robust.Shared.Configuration;
using Robust.Shared.Interfaces.Configuration;
using Robust.Shared.Interfaces.Network;
using Robust.Shared.Interfaces.Serialization;
using Robust.Shared.IoC;
using Robust.Shared.Log;
using Robust.Shared.Utility;
using UsernameHelpers = Robust.Shared.AuthLib.UsernameHelpers;
namespace Robust.Shared.Network
{
/// <summary>
/// <summary>
/// Callback for registered NetMessages.
/// </summary>
/// <param name="message">The message received.</param>
@@ -39,6 +37,8 @@ namespace Robust.Shared.Network
/// </summary>
public partial class NetManager : IClientNetManager, IServerNetManager, IDisposable
{
internal const int AesKeyLength = 32;
[Dependency] private readonly IRobustSerializer _serializer = default!;
private static readonly Counter SentPacketsMetrics = Metrics.CreateCounter(
@@ -92,18 +92,21 @@ namespace Robust.Shared.Network
/// </summary>
private readonly Dictionary<NetConnection, NetChannel> _channels = new Dictionary<NetConnection, NetChannel>();
private readonly Dictionary<NetConnection, NetSessionId> _assignedSessions =
new Dictionary<NetConnection, NetSessionId>();
private readonly Dictionary<string, NetConnection> _assignedUsernames = new Dictionary<string, NetConnection>();
private readonly Dictionary<NetUserId, NetConnection> _assignedUserIds =
new Dictionary<NetUserId, NetConnection>();
// Used for processing incoming net messages.
private readonly NetMsgEntry[] _netMsgFunctions = new NetMsgEntry[256];
// Used for processing outgoing net messages.
private readonly Dictionary<Type, Func<NetMessage>> _blankNetMsgFunctions = new Dictionary<Type, Func<NetMessage>>();
private readonly Dictionary<Type, Func<NetMessage>> _blankNetMsgFunctions =
new Dictionary<Type, Func<NetMessage>>();
private readonly Dictionary<Type, long> _bandwidthUsage = new Dictionary<Type, long>();
[Dependency] private readonly IConfigurationManager _config = default!;
[Dependency] private readonly IConfigurationManagerInternal _config = default!;
/// <summary>
/// Holds lookup table for NetMessage.Id -> NetMessage.Type
@@ -123,11 +126,18 @@ namespace Robust.Shared.Network
// Client connect happens during status changed and such callbacks, so we need to defer deletion of these.
private readonly List<NetPeer> _toCleanNetPeers = new List<NetPeer>();
private readonly Dictionary<NetConnection, TaskCompletionSource<object?>> _awaitingDisconnect
= new Dictionary<NetConnection, TaskCompletionSource<object?>>();
/// <inheritdoc />
public int Port => _config.GetCVar<int>("net.port");
public bool IsAuthEnabled => _config.GetCVar<bool>("auth.enabled");
public IReadOnlyDictionary<Type, long> MessageBandwidthUsage => _bandwidthUsage;
private NetEncryption? _clientEncryption;
/// <inheritdoc />
public bool IsServer { get; private set; }
@@ -223,6 +233,10 @@ namespace Robust.Shared.Network
IsServer = isServer;
_config.OnValueChanged(CVars.NetVerbose, NetVerboseChanged);
if (isServer)
{
_config.OnValueChanged(CVars.AuthMode, i => Auth = (AuthMode) i, invokeImmediately: true);
}
#if DEBUG
_config.OnValueChanged(CVars.NetFakeLoss, _fakeLossChanged);
_config.OnValueChanged(CVars.NetFakeLagMin, _fakeLagMinChanged);
@@ -230,17 +244,20 @@ namespace Robust.Shared.Network
_config.OnValueChanged(CVars.NetFakeDuplicates, FakeDuplicatesChanged);
#endif
_strings.Initialize(() =>
{
Logger.InfoS("net","Message string table loaded.");
}, UpdateNetMessageFunctions);
_strings.Initialize(() => { Logger.InfoS("net", "Message string table loaded."); },
UpdateNetMessageFunctions);
_serializer.ClientHandshakeComplete += () =>
{
Logger.InfoS("net","Client completed serializer handshake.");
Logger.InfoS("net", "Client completed serializer handshake.");
OnConnected(ServerChannel!);
};
_initialized = true;
if (IsServer)
{
SAGenerateRsaKeys();
}
}
private void UpdateNetMessageFunctions(MsgStringTableEntries.Entry[] entries)
@@ -284,7 +301,8 @@ namespace Robust.Shared.Network
var config = _getBaseNetPeerConfig();
config.LocalAddress = address;
config.Port = Port;
config.EnableMessageType(NetIncomingMessageType.ConnectionApproval);
// Disabled for now since we aren't doing anything with the connection approval stuff.
// config.EnableMessageType(NetIncomingMessageType.ConnectionApproval);
if (address.AddressFamily == AddressFamily.InterNetworkV6 && dualStack)
{
@@ -389,6 +407,7 @@ namespace Robust.Shared.Network
case NetIncomingMessageType.ConnectionApproval:
HandleApproval(msg);
recycle = false;
break;
case NetIncomingMessageType.Data:
@@ -460,6 +479,7 @@ namespace Robust.Shared.Network
{
Disconnect?.Invoke(this, new NetDisconnectedArgs(ServerChannel, reason));
}
Shutdown(reason);
}
@@ -479,9 +499,11 @@ namespace Robust.Shared.Network
if (IsServer)
{
netConfig.SetMessageTypeEnabled(NetIncomingMessageType.ConnectionApproval, true);
netConfig.MaximumConnections = _config.GetCVar(CVars.GameMaxPlayers);
}
#if DEBUG
//Simulate Latency
netConfig.SimulatedLoss = _config.GetCVar<float>("net.fakeloss");
@@ -602,102 +624,34 @@ namespace Robust.Shared.Network
HandleDisconnect(peer, sender, reason);
}
if (_awaitingDisconnect.TryGetValue(sender, out var tcs))
{
tcs.TrySetResult(null);
}
break;
}
}
private void HandleApproval(NetIncomingMessage message)
private async void HandleInitialHandshakeComplete(NetPeerData peer,
NetConnection sender,
NetUserId userId,
string userName,
NetEncryption? encryption,
LoginType loginType)
{
// TODO: Maybe preemptively refuse connections here in some cases?
if (message.SenderConnection.Status != NetConnectionStatus.RespondedAwaitingApproval)
{
// This can happen if the approval message comes in after the state changes to disconnected.
// In that case just ignore it.
return;
}
message.SenderConnection.Approve();
}
private async void HandleHandshake(NetPeerData peer, NetConnection connection)
{
string requestedUsername;
try
{
var userNamePacket = await AwaitData(connection);
requestedUsername = userNamePacket.ReadString();
}
catch (ClientDisconnectedException)
{
return;
}
if (!UsernameHelpers.IsNameValid(requestedUsername, out var reason))
{
connection.Disconnect($"Username is invalid ({reason.ToText()}).");
return;
}
var endPoint = connection.RemoteEndPoint;
var name = requestedUsername;
var origName = name;
var iterations = 1;
while (_assignedSessions.Values.Any(u => u.Username == name))
{
// This is shit but I don't care.
name = $"{origName}_{++iterations}";
}
var session = new NetSessionId(name);
if (OnConnecting(endPoint, session))
{
_assignedSessions.Add(connection, session);
var msg = connection.Peer.CreateMessage();
msg.Write(name);
connection.Peer.SendMessage(msg, connection, NetDeliveryMethod.ReliableOrdered);
}
else
{
connection.Disconnect("Sorry, denied. Why? Couldn't tell you, I didn't implement a deny reason.");
return;
}
NetIncomingMessage okMsg;
try
{
okMsg = await AwaitData(connection);
}
catch (ClientDisconnectedException)
{
return;
}
if (okMsg.ReadString() != "ok")
{
connection.Disconnect("You should say ok.");
return;
}
Logger.InfoS("net", "Approved {ConnectionEndpoint} with username {Username} into the server",
connection.RemoteEndPoint, session);
// Handshake complete!
HandleInitialHandshakeComplete(peer, connection);
}
private async void HandleInitialHandshakeComplete(NetPeerData peer, NetConnection sender)
{
var session = _assignedSessions[sender];
var channel = new NetChannel(this, sender, session);
var channel = new NetChannel(this, sender, userId, userName, loginType);
_assignedUserIds.Add(userId, sender);
_assignedUsernames.Add(userName, sender);
_channels.Add(sender, channel);
peer.AddChannel(channel);
channel.Encryption = encryption;
_strings.SendFullTable(channel);
try
{
await Task.Delay(1000);
await _serializer.Handshake(channel);
}
catch (TaskCanceledException)
@@ -715,8 +669,10 @@ namespace Robust.Shared.Network
{
var channel = _channels[connection];
Logger.InfoS("net", "{ConnectionEndpoint}: Disconnected ({DisconnectReason})", channel.RemoteEndPoint, reason);
_assignedSessions.Remove(connection);
Logger.InfoS("net", "{ConnectionEndpoint}: Disconnected ({DisconnectReason})", channel.RemoteEndPoint,
reason);
_assignedUsernames.Remove(channel.UserName);
_assignedUserIds.Remove(channel.UserId);
OnDisconnected(channel, reason);
_channels.Remove(connection);
@@ -778,7 +734,14 @@ namespace Robust.Shared.Network
return true;
}
var channel = GetChannel(msg.SenderConnection);
var channel = _channels[msg.SenderConnection];
var encryption = IsServer ? channel.Encryption : _clientEncryption;
if (encryption != null)
{
msg.Decrypt(encryption);
}
var id = msg.ReadByte();
@@ -797,7 +760,7 @@ namespace Robust.Shared.Network
var instance = entry.CreateFunction(channel);
instance.MsgChannel = channel;
#if DEBUG
#if DEBUG
if (!_bandwidthUsage.TryGetValue(type, out var bandwidth))
{
@@ -806,7 +769,7 @@ namespace Robust.Shared.Network
_bandwidthUsage[type] = bandwidth + msg.LengthBytes;
#endif
#endif
try
{
@@ -837,6 +800,7 @@ namespace Robust.Shared.Network
Logger.ErrorS("net",
$"{msg.SenderConnection.RemoteEndPoint}: exception in message handler for {type.Name}:\n{e}");
}
return true;
}
@@ -861,7 +825,8 @@ namespace Robust.Shared.Network
DebugTools.AssertNotNull(constructor);
var dynamicMethod = new DynamicMethod($"_netMsg<>{name}", typeof(NetMessage), new[]{typeof(INetChannel)}, packetType, false);
var dynamicMethod = new DynamicMethod($"_netMsg<>{name}", typeof(NetMessage), new[] {typeof(INetChannel)},
packetType, false);
dynamicMethod.DefineParameter(1, ParameterAttributes.In, "channel");
@@ -920,7 +885,8 @@ namespace Robust.Shared.Network
DebugTools.AssertNotNull(constructor);
var dynamicMethod = new DynamicMethod($"_netMsg<>{type.Name}", typeof(NetMessage), Array.Empty<Type>(), type, false);
var dynamicMethod = new DynamicMethod($"_netMsg<>{type.Name}", typeof(NetMessage), Array.Empty<Type>(),
type, false);
var gen = dynamicMethod.GetILGenerator();
gen.Emit(OpCodes.Ldnull);
gen.Emit(OpCodes.Newobj, constructor);
@@ -952,16 +918,9 @@ namespace Robust.Shared.Network
if (!IsConnected)
return;
foreach (var peer in _netPeers)
foreach (var channel in _channels.Values)
{
var packet = BuildMessage(message, peer.Peer);
var method = message.DeliveryMethod;
if (peer.Channels.Count == 0)
{
continue;
}
peer.Peer.SendMessage(packet, peer.ConnectionsWithChannels, method, 0);
ServerSendMessage(message, channel);
}
}
@@ -974,6 +933,11 @@ namespace Robust.Shared.Network
var peer = channel.Connection.Peer;
var packet = BuildMessage(message, peer);
if (channel.Encryption != null)
{
packet.Encrypt(channel.Encryption);
}
var method = message.DeliveryMethod;
peer.SendMessage(packet, channel.Connection, method);
}
@@ -1006,6 +970,11 @@ namespace Robust.Shared.Network
var peer = _netPeers[0];
var packet = BuildMessage(message, peer.Peer);
var method = message.DeliveryMethod;
if (_clientEncryption != null)
{
packet.Encrypt(_clientEncryption);
}
peer.Peer.SendMessage(packet, peer.ConnectionsWithChannels[0], method);
}
@@ -1013,31 +982,45 @@ namespace Robust.Shared.Network
#region Events
protected virtual bool OnConnecting(IPEndPoint ip, NetSessionId sessionId)
private async Task<NetConnectingArgs> OnConnecting(
IPEndPoint ip,
NetUserId userId,
string userName,
LoginType loginType)
{
var args = new NetConnectingArgs(sessionId, ip);
Connecting?.Invoke(this, args);
return !args.Deny;
var args = new NetConnectingArgs(userId, ip, userName, loginType);
foreach (var conn in _connectingEvent)
{
await conn(args);
}
return args;
}
protected virtual void OnConnectFailed(string reason)
private void OnConnectFailed(string reason)
{
var args = new NetConnectFailArgs(reason);
ConnectFailed?.Invoke(this, args);
}
protected virtual void OnConnected(INetChannel channel)
private void OnConnected(INetChannel channel)
{
Connected?.Invoke(this, new NetChannelArgs(channel));
}
protected virtual void OnDisconnected(INetChannel channel, string reason)
private void OnDisconnected(INetChannel channel, string reason)
{
Disconnect?.Invoke(this, new NetDisconnectedArgs(channel, reason));
}
private readonly List<Func<NetConnectingArgs, Task>> _connectingEvent
= new List<Func<NetConnectingArgs, Task>>();
/// <inheritdoc />
public event EventHandler<NetConnectingArgs>? Connecting;
public event Func<NetConnectingArgs, Task> Connecting
{
add => _connectingEvent.Add(value);
remove => _connectingEvent.Remove(value);
}
/// <inheritdoc />
public event EventHandler<NetConnectFailArgs>? ConnectFailed;
@@ -1075,7 +1058,9 @@ namespace Robust.Shared.Network
private class NetPeerData
{
public readonly NetPeer Peer;
public readonly List<NetChannel> Channels = new List<NetChannel>();
// So that we can do ServerSendToAll without a list copy.
public readonly List<NetConnection> ConnectionsWithChannels = new List<NetConnection>();

View File

@@ -99,9 +99,9 @@ namespace Robust.Shared.Network
case MsgGroups.Entity:
return NetDeliveryMethod.Unreliable;
case MsgGroups.Core:
case MsgGroups.String:
case MsgGroups.Command:
return NetDeliveryMethod.ReliableUnordered;
case MsgGroups.String:
case MsgGroups.EntityEvent:
return NetDeliveryMethod.ReliableOrdered;
default:

View File

@@ -1,3 +1,4 @@
using System;
using Lidgren.Network;
using Robust.Shared.GameObjects;
using Robust.Shared.Map;
@@ -55,5 +56,19 @@ namespace Robust.Shared.Network
{
message.Write(tick.Value);
}
public static Guid ReadGuid(this NetIncomingMessage message)
{
Span<byte> span = stackalloc byte[16];
message.ReadBytes(span);
return new Guid(span);
}
public static void Write(this NetOutgoingMessage message, Guid guid)
{
Span<byte> span = stackalloc byte[16];
guid.TryWriteBytes(span);
message.Write(span);
}
}
}

View File

@@ -1,50 +0,0 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Robust.Shared.Serialization;
namespace Robust.Shared.Network
{
[Serializable, NetSerializable]
public struct NetSessionId : IEquatable<NetSessionId>
{
public readonly string Username;
public NetSessionId(string name)
{
Username = name;
}
public override bool Equals(object? obj)
{
return obj is NetSessionId && Equals((NetSessionId)obj);
}
public bool Equals(NetSessionId other)
{
return Username == other.Username;
}
public override int GetHashCode()
{
return -182246463 + EqualityComparer<string>.Default.GetHashCode(Username);
}
public override string ToString()
{
return Username;
}
public static bool operator ==(NetSessionId id1, NetSessionId id2)
{
return id1.Equals(id2);
}
public static bool operator !=(NetSessionId id1, NetSessionId id2)
{
return !(id1 == id2);
}
}
}

View File

@@ -0,0 +1,46 @@
using System;
using Robust.Shared.Serialization;
namespace Robust.Shared.Network
{
[Serializable, NetSerializable]
public struct NetUserId : IEquatable<NetUserId>
{
public readonly Guid UserId;
public NetUserId(Guid userId)
{
UserId = userId;
}
public override bool Equals(object? obj)
{
return obj is NetUserId id && Equals(id);
}
public bool Equals(NetUserId other)
{
return UserId == other.UserId;
}
public override int GetHashCode()
{
return UserId.GetHashCode();
}
public override string ToString()
{
return UserId.ToString();
}
public static bool operator ==(NetUserId id1, NetUserId id2)
{
return id1.Equals(id2);
}
public static bool operator !=(NetUserId id1, NetUserId id2)
{
return !(id1 == id2);
}
}
}

View File

@@ -1,4 +1,5 @@
using Robust.Shared.Network;
using System;
using Robust.Shared.Network;
namespace Robust.Shared.Players
{
@@ -10,7 +11,7 @@ namespace Robust.Shared.Players
/// <summary>
/// The UID of this session.
/// </summary>
NetSessionId SessionId { get; }
NetUserId UserId { get; }
/// <summary>
/// Current name of this player.

View File

@@ -227,7 +227,7 @@ namespace Robust.Shared.Serialization
if (_incompleteHandshakes.TryGetValue(channel, out var handshake))
{
var tcs = handshake.Tcs;
LogSzr.Debug($"Cancelling handshake for disconnected client {channel.SessionId}");
LogSzr.Debug($"Cancelling handshake for disconnected client {channel.UserId}");
tcs.SetCanceled();
}
@@ -328,7 +328,7 @@ namespace Robust.Shared.Serialization
DebugTools.Assert(_dict.Locked);
var channel = msgMapStr.MsgChannel;
LogSzr.Debug($"Received handshake from {channel.SessionId}.");
LogSzr.Debug($"Received handshake from {channel.UserName}.");
if (!_incompleteHandshakes.TryGetValue(channel, out var handshake))
{
@@ -338,7 +338,7 @@ namespace Robust.Shared.Serialization
if (!msgMapStr.NeedsStrings)
{
LogSzr.Debug($"Completing handshake with {channel.SessionId}.");
LogSzr.Debug($"Completing handshake with {channel.UserName}.");
handshake.Tcs.SetResult(null);
_incompleteHandshakes.Remove(channel);
@@ -356,7 +356,7 @@ namespace Robust.Shared.Serialization
var strings = _net.CreateNetMessage<MsgMapStrStrings>();
strings.Package = _mappedStringsPackage;
LogSzr.Debug(
$"Sending {_mappedStringsPackage!.Length} bytes sized mapped strings package to {channel.SessionId}.");
$"Sending {_mappedStringsPackage!.Length} bytes sized mapped strings package to {channel.UserName}.");
_net.ServerSendMessage(strings, channel);
}

View File

@@ -33,6 +33,7 @@ namespace Robust.Shared
{
IoCManager.Register<IComponentManager, ComponentManager>();
IoCManager.Register<IConfigurationManager, ConfigurationManager>();
IoCManager.Register<IConfigurationManagerInternal, ConfigurationManager>();
IoCManager.Register<IDynamicTypeFactory, DynamicTypeFactory>();
IoCManager.Register<IEntitySystemManager, EntitySystemManager>();
IoCManager.Register<IGameTiming, GameTiming>();

View File

@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Threading.Channels;
using System.Threading.Tasks;
using Robust.Shared.Interfaces.Network;
using Robust.Shared.Network;
using Robust.Shared.Utility;
@@ -97,24 +98,34 @@ namespace Robust.UnitTesting
{
DebugTools.Assert(IsServer);
var writer = connect.ChannelWriter;
var uid = _genConnectionUid();
var sessionId = new NetSessionId($"integration_{uid}");
var connectArgs =
new NetConnectingArgs(sessionId, new IPEndPoint(IPAddress.IPv6Loopback, 0));
Connecting?.Invoke(this, connectArgs);
if (connectArgs.Deny)
async void DoConnect()
{
writer.TryWrite(new DeniedConnectMessage());
continue;
var writer = connect.ChannelWriter;
var uid = _genConnectionUid();
var sessionId = new NetUserId(Guid.NewGuid());
var userName = $"integration_{uid}";
var args = await OnConnecting(
new IPEndPoint(IPAddress.IPv6Loopback, 0),
sessionId,
userName,
LoginType.Guest);
if (args.IsDenied)
{
writer.TryWrite(new DeniedConnectMessage());
return;
}
writer.TryWrite(new ConfirmConnectMessage(uid, sessionId, userName));
var channel = new IntegrationNetChannel(this, connect.ChannelWriter, uid, sessionId,
connect.Uid, userName);
_channels.Add(uid, channel);
Connected?.Invoke(this, new NetChannelArgs(channel));
}
writer.TryWrite(new ConfirmConnectMessage(uid, sessionId));
var channel = new IntegrationNetChannel(this, connect.ChannelWriter, uid, sessionId, connect.Uid);
_channels.Add(uid, channel);
Connected?.Invoke(this, new NetChannelArgs(channel));
DoConnect();
break;
}
@@ -180,7 +191,7 @@ namespace Robust.UnitTesting
DebugTools.Assert(IsClient);
var channel = new IntegrationNetChannel(this, NextConnectChannel!, _clientConnectingUid,
confirm.SessionId, confirm.AssignedUid);
confirm.UserId, confirm.AssignedUid, confirm.AssignedName);
_channels.Add(channel.ConnectionUid, channel);
@@ -194,6 +205,20 @@ namespace Robust.UnitTesting
}
}
private async Task<NetConnectingArgs> OnConnecting(
IPEndPoint ip,
NetUserId userId,
string userName,
LoginType loginType)
{
var args = new NetConnectingArgs(userId, ip, userName, loginType);
foreach (var conn in _connectingEvent)
{
await conn(args);
}
return args;
}
public void ServerSendToAll(NetMessage message)
{
DebugTools.Assert(IsServer);
@@ -222,7 +247,16 @@ namespace Robust.UnitTesting
}
}
public event EventHandler<NetConnectingArgs>? Connecting;
private readonly List<Func<NetConnectingArgs, Task>> _connectingEvent
= new List<Func<NetConnectingArgs, Task>>();
public event Func<NetConnectingArgs, Task> Connecting
{
add => _connectingEvent.Add(value);
remove => _connectingEvent.Remove(value);
}
public event EventHandler<NetChannelArgs>? Connected;
public event EventHandler<NetDisconnectedArgs>? Disconnect;
@@ -246,6 +280,11 @@ namespace Robust.UnitTesting
return (T) Activator.CreateInstance(typeof(T), (INetChannel?) null)!;
}
public byte[]? RsaPublicKey => null;
public AuthMode Auth => AuthMode.Disabled;
public Func<string, Task<NetUserId?>>? AssignUserIdCallback { get; set; }
public IServerNetManager.NetApprovalDelegate? HandleApprovalCallback { get; set; }
public void DisconnectChannel(INetChannel channel, string reason)
{
channel.Disconnect(reason);
@@ -253,6 +292,7 @@ namespace Robust.UnitTesting
INetChannel IClientNetManager.ServerChannel => ServerChannel;
public ClientConnectionState ClientConnectState => ClientConnectionState.NotConnecting;
public event Action<ClientConnectionState>? ClientConnectStateChanged
{
add { }
@@ -322,21 +362,24 @@ namespace Robust.UnitTesting
// TODO: Should this port value make sense?
public IPEndPoint RemoteEndPoint { get; } = new IPEndPoint(IPAddress.Loopback, 1212);
public NetSessionId SessionId { get; }
public NetUserId UserId { get; }
public string UserName { get; }
public LoginType AuthType => LoginType.Guest;
public short Ping => default;
public IntegrationNetChannel(IntegrationNetManager owner, ChannelWriter<object> otherChannel, int uid,
NetSessionId sessionId)
NetUserId userId, string userName)
{
_owner = owner;
ConnectionUid = uid;
SessionId = sessionId;
UserId = userId;
UserName = userName;
OtherChannel = otherChannel;
IsConnected = true;
}
public IntegrationNetChannel(IntegrationNetManager owner, ChannelWriter<object> otherChannel, int uid,
NetSessionId sessionId, int remoteUid) : this(owner, otherChannel, uid, sessionId)
NetUserId userId, int remoteUid, string userName) : this(owner, otherChannel, uid, userId, userName)
{
RemoteUid = uid;
}
@@ -371,14 +414,16 @@ namespace Robust.UnitTesting
private sealed class ConfirmConnectMessage
{
public ConfirmConnectMessage(int assignedUid, NetSessionId sessionId)
public ConfirmConnectMessage(int assignedUid, NetUserId userId, string assignedName)
{
AssignedUid = assignedUid;
SessionId = sessionId;
UserId = userId;
AssignedName = assignedName;
}
public int AssignedUid { get; }
public NetSessionId SessionId { get; }
public NetUserId UserId { get; }
public string AssignedName { get; }
}
private sealed class DeniedConnectMessage

View File

@@ -0,0 +1,23 @@
using NUnit.Framework;
using Robust.Shared.Configuration;
namespace Robust.UnitTesting.Shared.Configuration
{
internal sealed class ConfigurationManagerTest
{
[Test]
public void TestSecureCVar()
{
var cfg = new ConfigurationManager();
cfg.RegisterCVar("auth.token", "honk", CVar.SECURE);
Assert.That(() => cfg.GetCVar<string>("auth.token"), Throws.TypeOf<InvalidConfigurationException>());
Assert.That(() => cfg.GetCVarType("auth.token"), Throws.TypeOf<InvalidConfigurationException>());
Assert.That(() => cfg.SetCVar("auth.token", "foo"), Throws.TypeOf<InvalidConfigurationException>());
Assert.That(cfg.GetSecureCVar<string>("auth.token"), Is.EqualTo("honk"));
Assert.That(cfg.IsCVarRegistered("auth.token"), Is.False);
Assert.That(cfg.GetRegisteredCVars(), Does.Not.Contain("auth.token"));
}
}
}

View File

@@ -33,12 +33,16 @@ namespace Robust.UnitTesting.Shared.GameObjects
var container = new DependencyCollection();
container.Register<ILogManager, LogManager>();
container.Register<IConfigurationManager, ConfigurationManager>();
container.Register<IConfigurationManagerInternal, ConfigurationManager>();
container.Register<INetManager, NetManager>();
container.Register<IReflectionManager, ServerReflectionManager>();
container.Register<IRobustSerializer, RobustSerializer>();
container.Register<IRobustMappedStringSerializer, RobustMappedStringSerializer>();
container.BuildGraph();
container.Resolve<IConfigurationManager>().Initialize(true);
container.Resolve<IConfigurationManager>().LoadCVarsFromAssembly(typeof(IConfigurationManager).Assembly);
container.Resolve<IReflectionManager>().LoadAssemblies(AppDomain.CurrentDomain.GetAssemblyByName("Robust.Shared"));
IoCManager.InitThread(container);