Implement subscription ordering into EventBus. (#1823)

* Implement subscription ordering into EventBus.

* Topological helpers, rest of code uses shared topological sort, fix bugs.

* Fix tests.

Didn't realize that multi-subscriptions are allowed on input handlers like that??

* Improve and use topological sort helpers more.
This commit is contained in:
Pieter-Jan Briers
2021-06-14 21:34:30 +02:00
committed by GitHub
parent 33334d6f5c
commit fa0d4da6d1
12 changed files with 577 additions and 220 deletions

View File

@@ -106,8 +106,16 @@ namespace Robust.Shared.ContentPack
Logger.DebugS("res.mod", $"Verified assemblies in {checkerSw.ElapsedMilliseconds}ms");
}
var nodes = TopologicalSort.FromBeforeAfter(
files,
kv => kv.Key,
kv => kv.Value.Path,
_ => Array.Empty<string>(),
kv => kv.Value.references,
allowMissing: true); // missing refs would be non-content assemblies so allow that.
// Actually load them in the order they depend on each other.
foreach (var path in TopologicalSortModules(files))
foreach (var path in TopologicalSort.Sort(nodes))
{
Logger.DebugS("res.mod", $"Loading module: '{path}'");
try
@@ -136,38 +144,6 @@ namespace Robust.Shared.ContentPack
return true;
}
private static IEnumerable<ResourcePath> TopologicalSortModules(
IEnumerable<KeyValuePair<string, (ResourcePath Path, string[] references)>> modules)
{
var elems = modules.ToDictionary(
node => node.Key,
node => (node.Value.Path, refs: new HashSet<string>(node.Value.references)));
// Remove assembly references we aren't sorting for.
foreach (var (_, set) in elems.Values)
{
set.RemoveWhere(r => !elems.ContainsKey(r));
}
while (elems.Count > 0)
{
var elem = elems.FirstOrNull(x => x.Value.refs.Count == 0);
if (elem == null)
{
throw new InvalidOperationException(
"Found circular dependency in assembly dependency graph");
}
elems.Remove(elem.Value.Key);
foreach (var sElem in elems)
{
sElem.Value.refs.Remove(elem.Value.Key);
}
yield return elem.Value.Value.Path;
}
}
private static (string[] refs, string name) GetAssemblyReferenceData(Stream stream)
{
using var reader = new PEReader(stream);

View File

@@ -24,6 +24,15 @@ namespace Robust.Shared.GameObjects
void SubscribeEvent<T>(EventSource source, IEntityEventSubscriber subscriber,
EntityEventHandler<T> eventHandler) where T : notnull;
void SubscribeEvent<T>(
EventSource source,
IEntityEventSubscriber subscriber,
EntityEventHandler<T> eventHandler,
Type orderType,
Type[]? before=null,
Type[]? after=null)
where T : notnull;
/// <summary>
/// Unsubscribes all event handlers of a given type.
/// </summary>
@@ -137,7 +146,7 @@ namespace Robust.Shared.GameObjects
return;
// UnsubscribeEvent modifies _inverseEventSubscriptions, requires val to be cached
foreach (var (type, (source, originalHandler, handler)) in val.ToList())
foreach (var (type, (source, originalHandler, handler, _)) in val.ToList())
{
UnsubscribeEvent(source, type, originalHandler, handler, subscriber);
}
@@ -154,7 +163,36 @@ namespace Robust.Shared.GameObjects
}
/// <inheritdoc />
public void SubscribeEvent<T>(EventSource source, IEntityEventSubscriber subscriber, EntityEventHandler<T> eventHandler) where T : notnull
public void SubscribeEvent<T>(
EventSource source,
IEntityEventSubscriber subscriber,
EntityEventHandler<T> eventHandler)
where T : notnull
{
SubscribeEventCommon(source, subscriber, eventHandler, null);
}
public void SubscribeEvent<T>(
EventSource source,
IEntityEventSubscriber subscriber,
EntityEventHandler<T> eventHandler,
Type orderType,
Type[]? before=null,
Type[]? after=null)
where T : notnull
{
var order = new OrderingData(orderType, before, after);
SubscribeEventCommon(source, subscriber, eventHandler, order);
HandleOrderRegistration(typeof(T), order);
}
private void SubscribeEventCommon<T>(
EventSource source,
IEntityEventSubscriber subscriber,
EntityEventHandler<T> eventHandler,
OrderingData? order)
where T : notnull
{
if (source == EventSource.None)
throw new ArgumentOutOfRangeException(nameof(source));
@@ -166,7 +204,7 @@ namespace Robust.Shared.GameObjects
throw new ArgumentNullException(nameof(subscriber));
var eventType = typeof(T);
var subscriptionTuple = new Registration(source, eventHandler, ev => eventHandler((T) ev), eventHandler);
var subscriptionTuple = new Registration(source, eventHandler, ev => eventHandler((T) ev), eventHandler, order);
if (!_eventSubscriptions.TryGetValue(eventType, out var subscriptions))
_eventSubscriptions.Add(eventType, new List<Registration> {subscriptionTuple});
else if (!subscriptions.Any(p => p.Mask == source && p.Original == (Delegate) eventHandler))
@@ -184,7 +222,6 @@ namespace Robust.Shared.GameObjects
inverseSubscription
);
}
else if (!inverseSubscription.ContainsKey(eventType))
{
inverseSubscription.Add(eventType, subscriptionTuple);
@@ -290,15 +327,13 @@ namespace Robust.Shared.GameObjects
});
}
_awaitingMessages.Add(type, (source, reg, tcs));
return tcs.Task;
}
private void UnsubscribeEvent(EventSource source, Type eventType, Delegate originalHandler, EventHandler handler, IEntityEventSubscriber subscriber)
{
var tuple = new Registration(source, originalHandler, handler, originalHandler);
var tuple = new Registration(source, originalHandler, handler, originalHandler, null);
if (_eventSubscriptions.TryGetValue(eventType, out var subscriptions) && subscriptions.Contains(tuple))
subscriptions.Remove(tuple);
@@ -310,7 +345,11 @@ namespace Robust.Shared.GameObjects
{
var eventType = eventArgs.GetType();
if (_eventSubscriptions.TryGetValue(eventType, out var subs))
if (_orderedEvents.Contains(eventType))
{
ProcessSingleEventOrdered(source, eventArgs, eventType);
}
else if (_eventSubscriptions.TryGetValue(eventType, out var subs))
{
foreach (var handler in subs)
{
@@ -319,25 +358,20 @@ namespace Robust.Shared.GameObjects
}
}
if (_awaitingMessages.TryGetValue(eventType, out var awaiting))
{
var (mask, _, tcs) = awaiting;
if ((source & mask) != 0)
{
tcs.TrySetResult(eventArgs);
_awaitingMessages.Remove(eventType);
}
}
ProcessAwaitingMessages(source, eventArgs, eventType);
}
private void ProcessSingleEvent<T>(EventSource source, T eventArgs) where T : notnull
{
var eventType = typeof(T);
if (_eventSubscriptions.TryGetValue(eventType, out var subs))
if (_orderedEvents.Contains(eventType))
{
foreach (var (mask, originalHandler, _) in subs)
ProcessSingleEventOrdered(source, eventArgs, eventType);
}
else if (_eventSubscriptions.TryGetValue(eventType, out var subs))
{
foreach (var (mask, originalHandler, _, _) in subs)
{
if ((mask & source) != 0)
{
@@ -347,6 +381,13 @@ namespace Robust.Shared.GameObjects
}
}
ProcessAwaitingMessages(source, eventArgs, eventType);
}
// Generic here so we can avoid boxing alloc unless actually awaiting.
private void ProcessAwaitingMessages<T>(EventSource source, T eventArgs, Type eventType)
where T : notnull
{
if (_awaitingMessages.TryGetValue(eventType, out var awaiting))
{
var (mask1, _, tcs) = awaiting;
@@ -366,13 +407,20 @@ namespace Robust.Shared.GameObjects
public readonly Delegate Original;
public readonly EventHandler Handler;
public readonly OrderingData? Ordering;
public Registration(EventSource mask, Delegate original, EventHandler handler, object equalityToken)
public Registration(
EventSource mask,
Delegate original,
EventHandler handler,
object equalityToken,
OrderingData? ordering)
{
Mask = mask;
Original = original;
Handler = handler;
EqualityToken = equalityToken;
Ordering = ordering;
}
public bool Equals(Registration other)
@@ -403,11 +451,16 @@ namespace Robust.Shared.GameObjects
return !left.Equals(right);
}
public void Deconstruct(out EventSource mask, out Delegate originalHandler, out EventHandler handler)
public void Deconstruct(
out EventSource mask,
out Delegate originalHandler,
out EventHandler handler,
out OrderingData? order)
{
mask = Mask;
originalHandler = Original;
handler = Handler;
order = Ordering;
}
}
}

View File

@@ -14,6 +14,12 @@ namespace Robust.Shared.GameObjects
where TComp : IComponent
where TEvent : EntityEventArgs;
void SubscribeLocalEvent<TComp, TEvent>(
ComponentEventHandler<TComp, TEvent> handler,
Type orderType, Type[]? before=null, Type[]? after=null)
where TComp : IComponent
where TEvent : EntityEventArgs;
[Obsolete("Use the overload without the handler argument.")]
void UnsubscribeLocalEvent<TComp, TEvent>(ComponentEventHandler<TComp, TEvent> handler)
where TComp : IComponent
@@ -61,6 +67,12 @@ namespace Robust.Shared.GameObjects
public void RaiseLocalEvent<TEvent>(EntityUid uid, TEvent args, bool broadcast = true)
where TEvent : EntityEventArgs
{
if (_orderedEvents.Contains(typeof(TEvent)))
{
RaiseLocalOrdered(uid, args, broadcast);
return;
}
_eventTables.Dispatch(uid, typeof(TEvent), args);
// we also broadcast it so the call site does not have to.
@@ -76,7 +88,24 @@ namespace Robust.Shared.GameObjects
void EventHandler(EntityUid uid, IComponent comp, EntityEventArgs args)
=> handler(uid, (TComp) comp, (TEvent) args);
_eventTables.Subscribe(typeof(TComp), typeof(TEvent), EventHandler);
_eventTables.Subscribe(typeof(TComp), typeof(TEvent), EventHandler, null);
}
public void SubscribeLocalEvent<TComp, TEvent>(
ComponentEventHandler<TComp, TEvent> handler,
Type orderType,
Type[]? before=null,
Type[]? after=null)
where TComp : IComponent
where TEvent : EntityEventArgs
{
void EventHandler(EntityUid uid, IComponent comp, EntityEventArgs args)
=> handler(uid, (TComp) comp, (TEvent) args);
var orderData = new OrderingData(orderType, before, after);
_eventTables.Subscribe(typeof(TComp), typeof(TEvent), EventHandler, orderData);
HandleOrderRegistration(typeof(TEvent), orderData);
}
/// <inheritdoc />
@@ -105,7 +134,7 @@ namespace Robust.Shared.GameObjects
private Dictionary<EntityUid, Dictionary<Type, HashSet<Type>>> _eventTables;
// EventType -> CompType -> Handler
private Dictionary<Type, Dictionary<Type, DirectedEventHandler>> _subscriptions;
private Dictionary<Type, Dictionary<Type, (DirectedEventHandler handler, OrderingData? ordering)>> _subscriptions;
// prevents shitcode, get your subscriptions figured out before you start spawning entities
private bool _subscriptionLock;
@@ -148,25 +177,21 @@ namespace Robust.Shared.GameObjects
RemoveComponent(e.OwnerUid, e.Component.GetType());
}
public void Subscribe(Type compType, Type eventType, DirectedEventHandler handler)
public void Subscribe(Type compType, Type eventType, DirectedEventHandler handler, OrderingData? order)
{
if (_subscriptionLock)
throw new InvalidOperationException("Subscription locked.");
if (!_subscriptions.TryGetValue(compType, out var compSubs))
{
compSubs = new Dictionary<Type, DirectedEventHandler>();
compSubs = new Dictionary<Type, (DirectedEventHandler, OrderingData?)>();
_subscriptions.Add(compType, compSubs);
compSubs.Add(eventType, handler);
}
else
{
if (compSubs.ContainsKey(eventType))
throw new InvalidOperationException($"Duplicate Subscriptions for comp={compType.Name}, event={eventType.Name}");
compSubs.Add(eventType, handler);
}
if (compSubs.ContainsKey(eventType))
throw new InvalidOperationException($"Duplicate Subscriptions for comp={compType.Name}, event={eventType.Name}");
compSubs.Add(eventType, (handler, order));
}
public void Unsubscribe(Type compType, Type eventType)
@@ -245,15 +270,41 @@ namespace Robust.Shared.GameObjects
if(!_subscriptions.TryGetValue(compType, out var compSubs))
return;
if(!compSubs.TryGetValue(eventType, out var handler))
if(!compSubs.TryGetValue(eventType, out var sub))
return;
var (handler, _) = sub;
var component = _entMan.ComponentManager.GetComponent(euid, compType);
handler(euid, component, args);
}
}
public void CollectOrdered(
EntityUid euid,
Type eventType,
List<(EventHandler, OrderingData?)> found)
{
var eventTable = _eventTables[euid];
if(!eventTable.TryGetValue(eventType, out var subscribedComps))
return;
foreach (var compType in subscribedComps)
{
if(!_subscriptions.TryGetValue(compType, out var compSubs))
return;
if(!compSubs.TryGetValue(eventType, out var sub))
return;
var (handler, order) = sub;
var component = _entMan.ComponentManager.GetComponent(euid, compType);
found.Add((ev => handler(euid, component, (EntityEventArgs) ev), order));
}
}
public void DispatchComponent(EntityUid euid, IComponent component, Type eventType, EntityEventArgs args)
{
foreach (var type in GetReferences(component.GetType()))
@@ -261,9 +312,10 @@ namespace Robust.Shared.GameObjects
if (!_subscriptions.TryGetValue(type, out var compSubs))
continue;
if (!compSubs.TryGetValue(eventType, out var handler))
if (!compSubs.TryGetValue(eventType, out var sub))
continue;
var (handler, _) = sub;
handler(euid, component, args);
}
}

View File

@@ -0,0 +1,95 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Robust.Shared.Utility;
namespace Robust.Shared.GameObjects
{
internal partial class EntityEventBus
{
// TODO: Topological sort is currently done every time an event is emitted.
// This should be fine for low-volume stuff like interactions, but definitely not for anything high volume.
// Not sure if we could pre-cache the topological sort, here.
// Ordered event raising is slow so if this event has any ordering dependencies we use a slower path.
private readonly HashSet<Type> _orderedEvents = new();
private void ProcessSingleEventOrdered(EventSource source, object eventArgs, Type eventType)
{
var found = new List<(EventHandler, OrderingData?)>();
CollectBroadcastOrdered(source, eventType, found);
DispatchOrderedEvents(eventArgs, found);
}
private void CollectBroadcastOrdered(
EventSource source,
Type eventType,
List<(EventHandler, OrderingData?)> found)
{
if (!_eventSubscriptions.TryGetValue(eventType, out var subs))
return;
foreach (var handler in subs)
{
if ((handler.Mask & source) != 0)
found.Add((handler.Handler, handler.Ordering));
}
}
private void RaiseLocalOrdered<TEvent>(
EntityUid uid,
TEvent args,
bool broadcast)
where TEvent : EntityEventArgs
{
var found = new List<(EventHandler, OrderingData?)>();
if (broadcast)
CollectBroadcastOrdered(EventSource.Local, typeof(TEvent), found);
_eventTables.CollectOrdered(uid, typeof(TEvent), found);
DispatchOrderedEvents(args, found);
if (broadcast)
ProcessAwaitingMessages(EventSource.Local, args, typeof(TEvent));
}
private static void DispatchOrderedEvents(object eventArgs, List<(EventHandler, OrderingData?)> found)
{
var nodes = TopologicalSort.FromBeforeAfter(
found.Where(f => f.Item2 != null),
n => n.Item2!.OrderType,
n => n.Item1!,
n => n.Item2!.Before ?? Array.Empty<Type>(),
n => n.Item2!.After ?? Array.Empty<Type>(),
allowMissing: true);
foreach (var handler in TopologicalSort.Sort(nodes))
{
handler(eventArgs);
}
// Go over all handlers that don't have ordering so weren't included in the sort.
foreach (var (handler, orderData) in found)
{
if (orderData == null)
handler(eventArgs);
}
}
private void HandleOrderRegistration(Type eventType, OrderingData? data)
{
if (data == null)
return;
if (data.Before != null || data.After != null)
_orderedEvents.Add(eventType);
}
private sealed record OrderingData(Type OrderType, Type[]? Before, Type[]? After);
}
}

View File

@@ -7,16 +7,22 @@ namespace Robust.Shared.GameObjects
{
private List<SubBase>? _subscriptions;
protected void SubscribeNetworkEvent<T>(EntityEventHandler<T> handler)
// NOTE: EntityEventHandler<T> and EntitySessionEventHandler<T> CANNOT BE ORDERED BETWEEN EACH OTHER.
protected void SubscribeNetworkEvent<T>(
EntityEventHandler<T> handler,
Type[]? before=null, Type[]? after=null)
where T : notnull
{
EntityManager.EventBus.SubscribeEvent(EventSource.Network, this, handler);
EntityManager.EventBus.SubscribeEvent(EventSource.Network, this, handler, GetType(), before, after);
_subscriptions ??= new();
_subscriptions.Add(new SubBroadcast<T>(EventSource.Network));
}
protected void SubscribeNetworkEvent<T>(EntitySessionEventHandler<T> handler)
protected void SubscribeNetworkEvent<T>(
EntitySessionEventHandler<T> handler,
Type[]? before=null, Type[]? after=null)
where T : notnull
{
EntityManager.EventBus.SubscribeSessionEvent(EventSource.Network, this, handler);
@@ -25,16 +31,20 @@ namespace Robust.Shared.GameObjects
_subscriptions.Add(new SubBroadcast<EntitySessionMessage<T>>(EventSource.Network));
}
protected void SubscribeLocalEvent<T>(EntityEventHandler<T> handler)
protected void SubscribeLocalEvent<T>(
EntityEventHandler<T> handler,
Type[]? before=null, Type[]? after=null)
where T : notnull
{
EntityManager.EventBus.SubscribeEvent(EventSource.Local, this, handler);
EntityManager.EventBus.SubscribeEvent(EventSource.Local, this, handler, GetType(), before, after);
_subscriptions ??= new();
_subscriptions.Add(new SubBroadcast<T>(EventSource.Local));
}
protected void SubscribeLocalEvent<T>(EntitySessionEventHandler<T> handler)
protected void SubscribeLocalEvent<T>(
EntitySessionEventHandler<T> handler,
Type[]? before=null, Type[]? after=null)
where T : notnull
{
EntityManager.EventBus.SubscribeSessionEvent(EventSource.Local, this, handler);
@@ -57,8 +67,9 @@ namespace Robust.Shared.GameObjects
EntityManager.EventBus.UnsubscribeEvent<T>(EventSource.Local, this);
}
protected void SubscribeLocalEvent<TComp, TEvent>(ComponentEventHandler<TComp, TEvent> handler)
protected void SubscribeLocalEvent<TComp, TEvent>(
ComponentEventHandler<TComp, TEvent> handler,
Type[]? before=null, Type[]? after=null)
where TComp : IComponent
where TEvent : EntityEventArgs
{

View File

@@ -166,15 +166,15 @@ namespace Robust.Shared.GameObjects
Dictionary<Type, IEntitySystem>.ValueCollection systems,
Dictionary<Type, IEntitySystem> supertypeSystems)
{
var allNodes = new List<GraphNode<IEntitySystem>>();
var typeToNode = new Dictionary<Type, GraphNode<IEntitySystem>>();
var allNodes = new List<TopologicalSort.GraphNode<IEntitySystem>>();
var typeToNode = new Dictionary<Type, TopologicalSort.GraphNode<IEntitySystem>>();
foreach (var system in systems)
{
var node = new GraphNode<IEntitySystem>(system);
var node = new TopologicalSort.GraphNode<IEntitySystem>(system);
allNodes.Add(node);
typeToNode.Add(system.GetType(), node);
allNodes.Add(node);
}
foreach (var (type, system) in supertypeSystems)
@@ -183,52 +183,30 @@ namespace Robust.Shared.GameObjects
typeToNode[type] = node;
}
foreach (var node in allNodes)
foreach (var node in typeToNode.Values)
{
foreach (var after in node.System.UpdatesAfter)
foreach (var after in node.Value.UpdatesAfter)
{
var system = typeToNode[after];
node.DependsOn.Add(system);
system.Dependant.Add(node);
}
foreach (var before in node.System.UpdatesBefore)
foreach (var before in node.Value.UpdatesBefore)
{
var system = typeToNode[before];
system.DependsOn.Add(node);
node.Dependant.Add(system);
}
}
var order = TopologicalSort(allNodes).Select(p => p.System).ToArray();
var order = TopologicalSort.Sort(allNodes).ToArray();
var frameUpdate = order.Where(p => NeedsFrameUpdate(p.GetType()));
var update = order.Where(p => NeedsUpdate(p.GetType()));
return (frameUpdate, update);
}
internal static IEnumerable<GraphNode<T>> TopologicalSort<T>(IEnumerable<GraphNode<T>> nodes)
{
var elems = nodes.ToDictionary(node => node,
node => new HashSet<GraphNode<T>>(node.DependsOn));
while (elems.Count > 0)
{
var elem =
elems.FirstOrDefault(x => x.Value.Count == 0);
if (elem.Key == null)
{
throw new InvalidOperationException(
"Found circular dependency when resolving entity system update dependency graph");
}
elems.Remove(elem.Key);
foreach (var selem in elems)
{
selem.Value.Remove(elem.Key);
}
yield return elem.Key;
}
}
private static IEnumerable<Type> GetBaseTypes(Type type) {
if(type.BaseType == null) return type.GetInterfaces();
@@ -344,18 +322,6 @@ namespace Robust.Shared.GameObjects
return mFrameUpdate!.DeclaringType != typeof(EntitySystem);
}
[DebuggerDisplay("GraphNode: {" + nameof(System) + "}")]
internal sealed class GraphNode<T>
{
public readonly T System;
public readonly List<GraphNode<T>> DependsOn = new();
public GraphNode(T system)
{
System = system;
}
}
private struct UpdateReg
{
[ViewVariables] public IEntitySystem System;

View File

@@ -14,6 +14,7 @@ using Robust.Shared.Physics.Controllers;
using Robust.Shared.Physics.Dynamics;
using Robust.Shared.Reflection;
using Robust.Shared.Timing;
using Robust.Shared.Utility;
using DependencyAttribute = Robust.Shared.IoC.DependencyAttribute;
using Logger = Robust.Shared.Log.Logger;
@@ -114,45 +115,24 @@ namespace Robust.Shared.GameObjects
{
var reflectionManager = IoCManager.Resolve<IReflectionManager>();
var typeFactory = IoCManager.Resolve<IDynamicTypeFactory>();
var allControllerTypes = new List<Type>();
var instantiated = new List<VirtualController>();
foreach (var type in reflectionManager.GetAllChildren(typeof(VirtualController)))
{
if (type.IsAbstract) continue;
allControllerTypes.Add(type);
if (type.IsAbstract)
continue;
instantiated.Add(typeFactory.CreateInstance<VirtualController>(type));
}
var instantiated = new Dictionary<Type, VirtualController>();
var nodes = TopologicalSort.FromBeforeAfter(
instantiated,
c => c.GetType(),
c => c,
c => c.UpdatesBefore,
c => c.UpdatesAfter);
foreach (var type in allControllerTypes)
{
instantiated.Add(type, (VirtualController) typeFactory.CreateInstance(type));
}
// Build dependency graph, copied from EntitySystemManager *COUGH
var nodes = new Dictionary<Type, EntitySystemManager.GraphNode<VirtualController>>();
foreach (var (_, controller) in instantiated)
{
var node = new EntitySystemManager.GraphNode<VirtualController>(controller);
nodes[controller.GetType()] = node;
}
foreach (var (type, node) in nodes)
{
foreach (var before in instantiated[type].UpdatesBefore)
{
nodes[before].DependsOn.Add(node);
}
foreach (var after in instantiated[type].UpdatesAfter)
{
node.DependsOn.Add(nodes[after]);
}
}
_controllers = GameObjects.EntitySystemManager.TopologicalSort(nodes.Values).Select(c => c.System).ToList();
_controllers = TopologicalSort.Sort(nodes).ToList();
foreach (var controller in _controllers)
{

View File

@@ -2,6 +2,7 @@ using System;
using System.Collections.Generic;
using System.Linq;
using Robust.Shared.Log;
using Robust.Shared.Utility;
namespace Robust.Shared.Input.Binding
{
@@ -13,8 +14,8 @@ namespace Robust.Shared.Input.Binding
// handlers in the order they should be resolved for the given key function.
// internally we use a graph to construct this but we render it down to a flattened
// list so we don't need to do any graph traversal at query time
private Dictionary<BoundKeyFunction, List<InputCmdHandler>> _bindingsForKey =
new();
private Dictionary<BoundKeyFunction, List<InputCmdHandler>> _bindingsForKey = new();
private bool _graphDirty = false;
/// <inheritdoc />
public void Register<TOwner>(CommandBinds commandBinds)
@@ -41,12 +42,15 @@ namespace Robust.Shared.Input.Binding
_bindings.Add(new TypedCommandBind(owner, binding));
}
RebuildGraph();
_graphDirty = true;
}
/// <inheritdoc />
public IEnumerable<InputCmdHandler> GetHandlers(BoundKeyFunction function)
{
if (_graphDirty)
RebuildGraph();
if (_bindingsForKey.TryGetValue(function, out var handlers))
{
return handlers;
@@ -58,7 +62,8 @@ namespace Robust.Shared.Input.Binding
public void Unregister(Type owner)
{
_bindings.RemoveAll(binding => binding.ForType == owner);
RebuildGraph();
_graphDirty = true;
}
/// <inheritdoc />
@@ -67,7 +72,7 @@ namespace Robust.Shared.Input.Binding
Unregister(typeof(TOwner));
}
private void RebuildGraph()
internal void RebuildGraph()
{
_bindingsForKey.Clear();
@@ -77,6 +82,7 @@ namespace Robust.Shared.Input.Binding
}
_graphDirty = false;
}
private Dictionary<BoundKeyFunction, List<TypedCommandBind>> FunctionToBindings()
@@ -105,16 +111,16 @@ namespace Robust.Shared.Input.Binding
//TODO: Probably could be optimized if needed! Generally shouldn't be a big issue since there is a relatively
// tiny amount of bindings
List<GraphNode> allNodes = new();
Dictionary<Type,List<GraphNode>> typeToNode = new();
List<TopologicalSort.GraphNode<TypedCommandBind>> allNodes = new();
Dictionary<Type,List<TopologicalSort.GraphNode<TypedCommandBind>>> typeToNode = new();
// build the dict for quick lookup on type
foreach (var binding in bindingsForFunction)
{
if (!typeToNode.ContainsKey(binding.ForType))
{
typeToNode[binding.ForType] = new List<GraphNode>();
typeToNode[binding.ForType] = new List<TopologicalSort.GraphNode<TypedCommandBind>>();
}
var newNode = new GraphNode(binding);
var newNode = new TopologicalSort.GraphNode<TypedCommandBind>(binding);
typeToNode[binding.ForType].Add(newNode);
allNodes.Add(newNode);
}
@@ -122,7 +128,7 @@ namespace Robust.Shared.Input.Binding
//add the graph edges
foreach (var curBinding in allNodes)
{
foreach (var afterType in curBinding.TypedCommandBind.CommandBind.After)
foreach (var afterType in curBinding.Value.CommandBind.After)
{
// curBinding should always fire after bindings associated with this afterType, i.e.
// this binding DEPENDS ON afterTypes' bindings
@@ -130,11 +136,12 @@ namespace Robust.Shared.Input.Binding
{
foreach (var afterBinding in afterBindings)
{
curBinding.DependsOn.Add(afterBinding);
afterBinding.Dependant.Add(curBinding);
}
}
}
foreach (var beforeType in curBinding.TypedCommandBind.CommandBind.Before)
foreach (var beforeType in curBinding.Value.CommandBind.Before)
{
// curBinding should always fire before bindings associated with this beforeType, i.e.
// beforeTypes' bindings DEPENDS ON this binding
@@ -142,7 +149,7 @@ namespace Robust.Shared.Input.Binding
{
foreach (var beforeBinding in beforeBindings)
{
beforeBinding.DependsOn.Add(curBinding);
curBinding.Dependant.Add(beforeBinding);
}
}
}
@@ -151,54 +158,7 @@ namespace Robust.Shared.Input.Binding
//TODO: Log graph structure for debugging
//use toposort to build the final result
var topoSorted = TopologicalSort(allNodes, function);
List<InputCmdHandler> result = new();
foreach (var node in topoSorted)
{
result.Add(node.TypedCommandBind.CommandBind.Handler);
}
return result;
}
//Adapted from https://stackoverflow.com/a/24058279
private static IEnumerable<GraphNode> TopologicalSort(IEnumerable<GraphNode> nodes, BoundKeyFunction function)
{
var elems = nodes.ToDictionary(node => node,
node => new HashSet<GraphNode>(node.DependsOn));
while (elems.Count > 0)
{
var elem =
elems.FirstOrDefault(x => x.Value.Count == 0);
if (elem.Key == null)
{
throw new InvalidOperationException("Found circular dependency when resolving" +
$" command binding handler order for key function {function.FunctionName}." +
$" Please check the systems which register bindings for" +
$" this function and eliminate the circular dependency.");
}
elems.Remove(elem.Key);
foreach (var selem in elems)
{
selem.Value.Remove(elem.Key);
}
yield return elem.Key;
}
}
/// <summary>
/// node in our temporary dependency graph
/// </summary>
private class GraphNode
{
public List<GraphNode> DependsOn = new();
public readonly TypedCommandBind TypedCommandBind;
public GraphNode(TypedCommandBind typedCommandBind)
{
TypedCommandBind = typedCommandBind;
}
return TopologicalSort.Sort(allNodes).Select(c => c.CommandBind.Handler).ToList();
}
/// <summary>

View File

@@ -0,0 +1,125 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
namespace Robust.Shared.Utility
{
public sealed class TopologicalSort
{
public static IEnumerable<T> Sort<T>(IEnumerable<GraphNode<T>> nodes)
{
var totalVerts = 0;
var empty = new Queue<GraphNode<T>>();
var nodesArray = nodes.ToArray();
foreach (var node in nodesArray)
{
totalVerts += 1;
foreach (var dep in node.Dependant)
{
dep.DependsOnCount += 1;
}
}
foreach (var node in nodesArray)
{
if (node.DependsOnCount == 0)
empty.Enqueue(node);
}
while (empty.TryDequeue(out var node))
{
yield return node.Value;
totalVerts -= 1;
foreach (var dep in node.Dependant)
{
dep.DependsOnCount -= 1;
if (dep.DependsOnCount == 0)
empty.Enqueue(dep);
}
}
if (totalVerts != 0)
throw new InvalidOperationException("Graph contained cycle(s).");
}
// I will never stop using the word "datum".
public static IEnumerable<GraphNode<TValue>> FromBeforeAfter<TDatum, TValue>(
IEnumerable<TDatum> data,
Func<TDatum, TValue> keySelector,
Func<TDatum, IEnumerable<TValue>> beforeSelector,
Func<TDatum, IEnumerable<TValue>> afterSelector,
bool allowMissing = false)
where TValue : notnull
{
return FromBeforeAfter(data, keySelector, keySelector, afterSelector, beforeSelector, allowMissing);
}
public static IEnumerable<GraphNode<TValue>> FromBeforeAfter<TDatum, TKey, TValue>(
IEnumerable<TDatum> data,
Func<TDatum, TKey> keySelector,
Func<TDatum, TValue> valueSelector,
Func<TDatum, IEnumerable<TKey>> beforeSelector,
Func<TDatum, IEnumerable<TKey>> afterSelector,
bool allowMissing=false)
where TKey : notnull
{
var dict = new Dictionary<TKey, (TDatum datum, GraphNode<TValue> node)>();
foreach (var datum in data)
{
var key = keySelector(datum);
var value = valueSelector(datum);
dict.Add(key, (datum, new GraphNode<TValue>(value)));
}
foreach (var (key, (datum, node)) in dict)
{
foreach (var before in beforeSelector(datum))
{
if (dict.TryGetValue(before, out var entry))
{
node.Dependant.Add(entry.node);
}
else if (!allowMissing)
{
throw new InvalidOperationException($"Vertex '{before}' referenced by '{key}' was not found in the graph.");
}
}
foreach (var after in afterSelector(datum))
{
if (dict.TryGetValue(after, out var entry))
{
entry.node.Dependant.Add(node);
}
else if (!allowMissing)
{
throw new InvalidOperationException($"Vertex '{after}' referenced by '{key}' was not found in the graph.");
}
}
}
return dict.Values.Select(c => c.node);
}
[DebuggerDisplay("GraphNode: {" + nameof(Value) + "}")]
public class GraphNode<T>
{
public readonly T Value;
public readonly List<GraphNode<T>> Dependant = new();
// Used internal by sort implementation, do not touch.
internal int DependsOnCount;
public GraphNode(T value)
{
Value = value;
}
}
}
}

View File

@@ -166,11 +166,98 @@ namespace Robust.UnitTesting.Shared.GameObjects
}
}
[Test]
public void CompEventOrdered()
{
// Arrange
var entUid = new EntityUid(7);
var entManMock = new Mock<IEntityManager>();
var compManMock = new Mock<IComponentManager>();
var compFacMock = new Mock<IComponentFactory>();
void Setup<T>(out T instance) where T : IComponent, new()
{
IComponent? inst = instance = new T();
var reg = new Mock<IComponentRegistration>();
reg.Setup(m => m.References).Returns(new Type[] {typeof(T)});
compFacMock.Setup(m => m.GetRegistration(typeof(T))).Returns(reg.Object);
compManMock.Setup(m => m.TryGetComponent(entUid, typeof(T), out inst)).Returns(true);
compManMock.Setup(m => m.GetComponent(entUid, typeof(T))).Returns(inst);
}
Setup<OrderComponentA>(out var instA);
Setup<OrderComponentB>(out var instB);
Setup<OrderComponentC>(out var instC);
compManMock.Setup(m => m.ComponentFactory).Returns(compFacMock.Object);
entManMock.Setup(m => m.ComponentManager).Returns(compManMock.Object);
var bus = new EntityEventBus(entManMock.Object);
// Subscribe
var a = false;
var b = false;
var c = false;
void HandlerA(EntityUid uid, Component comp, TestEvent ev)
{
Assert.That(b, Is.False, "A should run before B");
Assert.That(c, Is.False, "A should run before C");
a = true;
}
void HandlerB(EntityUid uid, Component comp, TestEvent ev)
{
Assert.That(c, Is.True, "B should run after C");
b = true;
}
void HandlerC(EntityUid uid, Component comp, TestEvent ev) => c = true;
bus.SubscribeLocalEvent<OrderComponentA, TestEvent>(HandlerA, typeof(OrderComponentA), before: new []{typeof(OrderComponentB), typeof(OrderComponentC)});
bus.SubscribeLocalEvent<OrderComponentB, TestEvent>(HandlerB, typeof(OrderComponentB), after: new []{typeof(OrderComponentC)});
bus.SubscribeLocalEvent<OrderComponentC, TestEvent>(HandlerC, typeof(OrderComponentC));
// add a component to the system
entManMock.Raise(m=>m.EntityAdded += null, entManMock.Object, entUid);
compManMock.Raise(m => m.ComponentAdded += null, new AddedComponentEventArgs(instA, entUid));
compManMock.Raise(m => m.ComponentAdded += null, new AddedComponentEventArgs(instB, entUid));
compManMock.Raise(m => m.ComponentAdded += null, new AddedComponentEventArgs(instC, entUid));
// Raise
var evntArgs = new TestEvent(5);
bus.RaiseLocalEvent(entUid, evntArgs);
// Assert
Assert.That(a, Is.True, "A did not fire");
Assert.That(b, Is.True, "B did not fire");
Assert.That(c, Is.True, "C did not fire");
}
private class DummyComponent : Component
{
public override string Name => "Dummy";
}
private class OrderComponentA : Component
{
public override string Name => "OrderComponentA";
}
private class OrderComponentB : Component
{
public override string Name => "OrderComponentB";
}
private class OrderComponentC : Component
{
public override string Name => "OrderComponentC";
}
private class TestEvent : EntityEventArgs
{
public int TestNumber { get; }

View File

@@ -469,6 +469,58 @@ namespace Robust.UnitTesting.Shared.GameObjects
// Assert
Assert.Throws<InvalidOperationException>(Code);
}
[Test]
public void RaiseEvent_Ordered()
{
// Arrange
var bus = BusFactory();
// Expected order is A -> C -> B
var a = false;
var b = false;
var c = false;
void HandlerA(TestEventArgs ev)
{
Assert.That(b, Is.False, "A should run before B");
Assert.That(c, Is.False, "A should run before C");
a = true;
}
void HandlerB(TestEventArgs ev)
{
Assert.That(c, Is.True, "B should run after C");
b = true;
}
void HandlerC(TestEventArgs ev) => c = true;
bus.SubscribeEvent<TestEventArgs>(EventSource.Local, new SubA(), HandlerA, typeof(SubA), before: new []{typeof(SubB), typeof(SubC)});
bus.SubscribeEvent<TestEventArgs>(EventSource.Local, new SubB(), HandlerB, typeof(SubB), after: new []{typeof(SubC)});
bus.SubscribeEvent<TestEventArgs>(EventSource.Local, new SubC(), HandlerC, typeof(SubC));
// Act
bus.RaiseEvent(EventSource.Local, new TestEventArgs());
// Assert
Assert.That(a, Is.True, "A did not fire");
Assert.That(b, Is.True, "B did not fire");
Assert.That(c, Is.True, "C did not fire");
}
public sealed class SubA : IEntityEventSubscriber
{
}
public sealed class SubB : IEntityEventSubscriber
{
}
public sealed class SubC : IEntityEventSubscriber
{
}
}
internal class TestEventSubscriber : IEntityEventSubscriber { }

View File

@@ -222,12 +222,12 @@ namespace Robust.UnitTesting.Shared.Input.Binding
.Bind(bkf, bHandler1)
.Bind(bkf, bHandler2)
.Register<TypeB>(registry);
CommandBinds.Builder
.Bind(bkf, cHandler1)
.BindAfter(bkf, cHandler2, typeof(TypeA))
.Register<TypeC>(registry);
Assert.Throws<InvalidOperationException>(() =>
CommandBinds.Builder
.Bind(bkf, cHandler1)
.BindAfter(bkf, cHandler2, typeof(TypeA))
.Register<TypeC>(registry));
Assert.Throws<InvalidOperationException>(registry.RebuildGraph);
}
}
}