diff --git a/Robust.Shared/ContentPack/ModLoader.cs b/Robust.Shared/ContentPack/ModLoader.cs index af3233bda..6fc82a2d6 100644 --- a/Robust.Shared/ContentPack/ModLoader.cs +++ b/Robust.Shared/ContentPack/ModLoader.cs @@ -106,8 +106,16 @@ namespace Robust.Shared.ContentPack Logger.DebugS("res.mod", $"Verified assemblies in {checkerSw.ElapsedMilliseconds}ms"); } + var nodes = TopologicalSort.FromBeforeAfter( + files, + kv => kv.Key, + kv => kv.Value.Path, + _ => Array.Empty(), + kv => kv.Value.references, + allowMissing: true); // missing refs would be non-content assemblies so allow that. + // Actually load them in the order they depend on each other. - foreach (var path in TopologicalSortModules(files)) + foreach (var path in TopologicalSort.Sort(nodes)) { Logger.DebugS("res.mod", $"Loading module: '{path}'"); try @@ -136,38 +144,6 @@ namespace Robust.Shared.ContentPack return true; } - private static IEnumerable TopologicalSortModules( - IEnumerable> modules) - { - var elems = modules.ToDictionary( - node => node.Key, - node => (node.Value.Path, refs: new HashSet(node.Value.references))); - - // Remove assembly references we aren't sorting for. - foreach (var (_, set) in elems.Values) - { - set.RemoveWhere(r => !elems.ContainsKey(r)); - } - - while (elems.Count > 0) - { - var elem = elems.FirstOrNull(x => x.Value.refs.Count == 0); - if (elem == null) - { - throw new InvalidOperationException( - "Found circular dependency in assembly dependency graph"); - } - - elems.Remove(elem.Value.Key); - foreach (var sElem in elems) - { - sElem.Value.refs.Remove(elem.Value.Key); - } - - yield return elem.Value.Value.Path; - } - } - private static (string[] refs, string name) GetAssemblyReferenceData(Stream stream) { using var reader = new PEReader(stream); diff --git a/Robust.Shared/GameObjects/EntityEventBus.Broadcast.cs b/Robust.Shared/GameObjects/EntityEventBus.Broadcast.cs index 20103374c..1c43d7fad 100644 --- a/Robust.Shared/GameObjects/EntityEventBus.Broadcast.cs +++ b/Robust.Shared/GameObjects/EntityEventBus.Broadcast.cs @@ -24,6 +24,15 @@ namespace Robust.Shared.GameObjects void SubscribeEvent(EventSource source, IEntityEventSubscriber subscriber, EntityEventHandler eventHandler) where T : notnull; + void SubscribeEvent( + EventSource source, + IEntityEventSubscriber subscriber, + EntityEventHandler eventHandler, + Type orderType, + Type[]? before=null, + Type[]? after=null) + where T : notnull; + /// /// Unsubscribes all event handlers of a given type. /// @@ -137,7 +146,7 @@ namespace Robust.Shared.GameObjects return; // UnsubscribeEvent modifies _inverseEventSubscriptions, requires val to be cached - foreach (var (type, (source, originalHandler, handler)) in val.ToList()) + foreach (var (type, (source, originalHandler, handler, _)) in val.ToList()) { UnsubscribeEvent(source, type, originalHandler, handler, subscriber); } @@ -154,7 +163,36 @@ namespace Robust.Shared.GameObjects } /// - public void SubscribeEvent(EventSource source, IEntityEventSubscriber subscriber, EntityEventHandler eventHandler) where T : notnull + public void SubscribeEvent( + EventSource source, + IEntityEventSubscriber subscriber, + EntityEventHandler eventHandler) + where T : notnull + { + SubscribeEventCommon(source, subscriber, eventHandler, null); + } + + public void SubscribeEvent( + EventSource source, + IEntityEventSubscriber subscriber, + EntityEventHandler eventHandler, + Type orderType, + Type[]? before=null, + Type[]? after=null) + where T : notnull + { + var order = new OrderingData(orderType, before, after); + + SubscribeEventCommon(source, subscriber, eventHandler, order); + HandleOrderRegistration(typeof(T), order); + } + + private void SubscribeEventCommon( + EventSource source, + IEntityEventSubscriber subscriber, + EntityEventHandler eventHandler, + OrderingData? order) + where T : notnull { if (source == EventSource.None) throw new ArgumentOutOfRangeException(nameof(source)); @@ -166,7 +204,7 @@ namespace Robust.Shared.GameObjects throw new ArgumentNullException(nameof(subscriber)); var eventType = typeof(T); - var subscriptionTuple = new Registration(source, eventHandler, ev => eventHandler((T) ev), eventHandler); + var subscriptionTuple = new Registration(source, eventHandler, ev => eventHandler((T) ev), eventHandler, order); if (!_eventSubscriptions.TryGetValue(eventType, out var subscriptions)) _eventSubscriptions.Add(eventType, new List {subscriptionTuple}); else if (!subscriptions.Any(p => p.Mask == source && p.Original == (Delegate) eventHandler)) @@ -184,7 +222,6 @@ namespace Robust.Shared.GameObjects inverseSubscription ); } - else if (!inverseSubscription.ContainsKey(eventType)) { inverseSubscription.Add(eventType, subscriptionTuple); @@ -290,15 +327,13 @@ namespace Robust.Shared.GameObjects }); } - - _awaitingMessages.Add(type, (source, reg, tcs)); return tcs.Task; } private void UnsubscribeEvent(EventSource source, Type eventType, Delegate originalHandler, EventHandler handler, IEntityEventSubscriber subscriber) { - var tuple = new Registration(source, originalHandler, handler, originalHandler); + var tuple = new Registration(source, originalHandler, handler, originalHandler, null); if (_eventSubscriptions.TryGetValue(eventType, out var subscriptions) && subscriptions.Contains(tuple)) subscriptions.Remove(tuple); @@ -310,7 +345,11 @@ namespace Robust.Shared.GameObjects { var eventType = eventArgs.GetType(); - if (_eventSubscriptions.TryGetValue(eventType, out var subs)) + if (_orderedEvents.Contains(eventType)) + { + ProcessSingleEventOrdered(source, eventArgs, eventType); + } + else if (_eventSubscriptions.TryGetValue(eventType, out var subs)) { foreach (var handler in subs) { @@ -319,25 +358,20 @@ namespace Robust.Shared.GameObjects } } - if (_awaitingMessages.TryGetValue(eventType, out var awaiting)) - { - var (mask, _, tcs) = awaiting; - - if ((source & mask) != 0) - { - tcs.TrySetResult(eventArgs); - _awaitingMessages.Remove(eventType); - } - } + ProcessAwaitingMessages(source, eventArgs, eventType); } private void ProcessSingleEvent(EventSource source, T eventArgs) where T : notnull { var eventType = typeof(T); - if (_eventSubscriptions.TryGetValue(eventType, out var subs)) + if (_orderedEvents.Contains(eventType)) { - foreach (var (mask, originalHandler, _) in subs) + ProcessSingleEventOrdered(source, eventArgs, eventType); + } + else if (_eventSubscriptions.TryGetValue(eventType, out var subs)) + { + foreach (var (mask, originalHandler, _, _) in subs) { if ((mask & source) != 0) { @@ -347,6 +381,13 @@ namespace Robust.Shared.GameObjects } } + ProcessAwaitingMessages(source, eventArgs, eventType); + } + + // Generic here so we can avoid boxing alloc unless actually awaiting. + private void ProcessAwaitingMessages(EventSource source, T eventArgs, Type eventType) + where T : notnull + { if (_awaitingMessages.TryGetValue(eventType, out var awaiting)) { var (mask1, _, tcs) = awaiting; @@ -366,13 +407,20 @@ namespace Robust.Shared.GameObjects public readonly Delegate Original; public readonly EventHandler Handler; + public readonly OrderingData? Ordering; - public Registration(EventSource mask, Delegate original, EventHandler handler, object equalityToken) + public Registration( + EventSource mask, + Delegate original, + EventHandler handler, + object equalityToken, + OrderingData? ordering) { Mask = mask; Original = original; Handler = handler; EqualityToken = equalityToken; + Ordering = ordering; } public bool Equals(Registration other) @@ -403,11 +451,16 @@ namespace Robust.Shared.GameObjects return !left.Equals(right); } - public void Deconstruct(out EventSource mask, out Delegate originalHandler, out EventHandler handler) + public void Deconstruct( + out EventSource mask, + out Delegate originalHandler, + out EventHandler handler, + out OrderingData? order) { mask = Mask; originalHandler = Original; handler = Handler; + order = Ordering; } } } diff --git a/Robust.Shared/GameObjects/EntityEventBus.Directed.cs b/Robust.Shared/GameObjects/EntityEventBus.Directed.cs index 579aa21d8..e0470b320 100644 --- a/Robust.Shared/GameObjects/EntityEventBus.Directed.cs +++ b/Robust.Shared/GameObjects/EntityEventBus.Directed.cs @@ -14,6 +14,12 @@ namespace Robust.Shared.GameObjects where TComp : IComponent where TEvent : EntityEventArgs; + void SubscribeLocalEvent( + ComponentEventHandler handler, + Type orderType, Type[]? before=null, Type[]? after=null) + where TComp : IComponent + where TEvent : EntityEventArgs; + [Obsolete("Use the overload without the handler argument.")] void UnsubscribeLocalEvent(ComponentEventHandler handler) where TComp : IComponent @@ -61,6 +67,12 @@ namespace Robust.Shared.GameObjects public void RaiseLocalEvent(EntityUid uid, TEvent args, bool broadcast = true) where TEvent : EntityEventArgs { + if (_orderedEvents.Contains(typeof(TEvent))) + { + RaiseLocalOrdered(uid, args, broadcast); + return; + } + _eventTables.Dispatch(uid, typeof(TEvent), args); // we also broadcast it so the call site does not have to. @@ -76,7 +88,24 @@ namespace Robust.Shared.GameObjects void EventHandler(EntityUid uid, IComponent comp, EntityEventArgs args) => handler(uid, (TComp) comp, (TEvent) args); - _eventTables.Subscribe(typeof(TComp), typeof(TEvent), EventHandler); + _eventTables.Subscribe(typeof(TComp), typeof(TEvent), EventHandler, null); + } + + public void SubscribeLocalEvent( + ComponentEventHandler handler, + Type orderType, + Type[]? before=null, + Type[]? after=null) + where TComp : IComponent + where TEvent : EntityEventArgs + { + void EventHandler(EntityUid uid, IComponent comp, EntityEventArgs args) + => handler(uid, (TComp) comp, (TEvent) args); + + var orderData = new OrderingData(orderType, before, after); + + _eventTables.Subscribe(typeof(TComp), typeof(TEvent), EventHandler, orderData); + HandleOrderRegistration(typeof(TEvent), orderData); } /// @@ -105,7 +134,7 @@ namespace Robust.Shared.GameObjects private Dictionary>> _eventTables; // EventType -> CompType -> Handler - private Dictionary> _subscriptions; + private Dictionary> _subscriptions; // prevents shitcode, get your subscriptions figured out before you start spawning entities private bool _subscriptionLock; @@ -148,25 +177,21 @@ namespace Robust.Shared.GameObjects RemoveComponent(e.OwnerUid, e.Component.GetType()); } - public void Subscribe(Type compType, Type eventType, DirectedEventHandler handler) + public void Subscribe(Type compType, Type eventType, DirectedEventHandler handler, OrderingData? order) { if (_subscriptionLock) throw new InvalidOperationException("Subscription locked."); if (!_subscriptions.TryGetValue(compType, out var compSubs)) { - compSubs = new Dictionary(); + compSubs = new Dictionary(); _subscriptions.Add(compType, compSubs); - - compSubs.Add(eventType, handler); } - else - { - if (compSubs.ContainsKey(eventType)) - throw new InvalidOperationException($"Duplicate Subscriptions for comp={compType.Name}, event={eventType.Name}"); - compSubs.Add(eventType, handler); - } + if (compSubs.ContainsKey(eventType)) + throw new InvalidOperationException($"Duplicate Subscriptions for comp={compType.Name}, event={eventType.Name}"); + + compSubs.Add(eventType, (handler, order)); } public void Unsubscribe(Type compType, Type eventType) @@ -245,15 +270,41 @@ namespace Robust.Shared.GameObjects if(!_subscriptions.TryGetValue(compType, out var compSubs)) return; - if(!compSubs.TryGetValue(eventType, out var handler)) + if(!compSubs.TryGetValue(eventType, out var sub)) return; + var (handler, _) = sub; var component = _entMan.ComponentManager.GetComponent(euid, compType); handler(euid, component, args); } } + public void CollectOrdered( + EntityUid euid, + Type eventType, + List<(EventHandler, OrderingData?)> found) + { + var eventTable = _eventTables[euid]; + + if(!eventTable.TryGetValue(eventType, out var subscribedComps)) + return; + + foreach (var compType in subscribedComps) + { + if(!_subscriptions.TryGetValue(compType, out var compSubs)) + return; + + if(!compSubs.TryGetValue(eventType, out var sub)) + return; + + var (handler, order) = sub; + var component = _entMan.ComponentManager.GetComponent(euid, compType); + + found.Add((ev => handler(euid, component, (EntityEventArgs) ev), order)); + } + } + public void DispatchComponent(EntityUid euid, IComponent component, Type eventType, EntityEventArgs args) { foreach (var type in GetReferences(component.GetType())) @@ -261,9 +312,10 @@ namespace Robust.Shared.GameObjects if (!_subscriptions.TryGetValue(type, out var compSubs)) continue; - if (!compSubs.TryGetValue(eventType, out var handler)) + if (!compSubs.TryGetValue(eventType, out var sub)) continue; + var (handler, _) = sub; handler(euid, component, args); } } diff --git a/Robust.Shared/GameObjects/EntityEventBus.Ordering.cs b/Robust.Shared/GameObjects/EntityEventBus.Ordering.cs new file mode 100644 index 000000000..4f38deae2 --- /dev/null +++ b/Robust.Shared/GameObjects/EntityEventBus.Ordering.cs @@ -0,0 +1,95 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Robust.Shared.Utility; + +namespace Robust.Shared.GameObjects +{ + internal partial class EntityEventBus + { + // TODO: Topological sort is currently done every time an event is emitted. + // This should be fine for low-volume stuff like interactions, but definitely not for anything high volume. + // Not sure if we could pre-cache the topological sort, here. + + // Ordered event raising is slow so if this event has any ordering dependencies we use a slower path. + private readonly HashSet _orderedEvents = new(); + + private void ProcessSingleEventOrdered(EventSource source, object eventArgs, Type eventType) + { + var found = new List<(EventHandler, OrderingData?)>(); + + CollectBroadcastOrdered(source, eventType, found); + + DispatchOrderedEvents(eventArgs, found); + } + + private void CollectBroadcastOrdered( + EventSource source, + Type eventType, + List<(EventHandler, OrderingData?)> found) + { + if (!_eventSubscriptions.TryGetValue(eventType, out var subs)) + return; + + foreach (var handler in subs) + { + if ((handler.Mask & source) != 0) + found.Add((handler.Handler, handler.Ordering)); + } + } + + private void RaiseLocalOrdered( + EntityUid uid, + TEvent args, + bool broadcast) + where TEvent : EntityEventArgs + { + var found = new List<(EventHandler, OrderingData?)>(); + + if (broadcast) + CollectBroadcastOrdered(EventSource.Local, typeof(TEvent), found); + + _eventTables.CollectOrdered(uid, typeof(TEvent), found); + + DispatchOrderedEvents(args, found); + + if (broadcast) + ProcessAwaitingMessages(EventSource.Local, args, typeof(TEvent)); + } + + private static void DispatchOrderedEvents(object eventArgs, List<(EventHandler, OrderingData?)> found) + { + var nodes = TopologicalSort.FromBeforeAfter( + found.Where(f => f.Item2 != null), + n => n.Item2!.OrderType, + n => n.Item1!, + n => n.Item2!.Before ?? Array.Empty(), + n => n.Item2!.After ?? Array.Empty(), + allowMissing: true); + + foreach (var handler in TopologicalSort.Sort(nodes)) + { + handler(eventArgs); + } + + // Go over all handlers that don't have ordering so weren't included in the sort. + foreach (var (handler, orderData) in found) + { + if (orderData == null) + handler(eventArgs); + } + } + + private void HandleOrderRegistration(Type eventType, OrderingData? data) + { + if (data == null) + return; + + if (data.Before != null || data.After != null) + _orderedEvents.Add(eventType); + } + + private sealed record OrderingData(Type OrderType, Type[]? Before, Type[]? After); + + } +} diff --git a/Robust.Shared/GameObjects/EntitySystem.Subscriptions.cs b/Robust.Shared/GameObjects/EntitySystem.Subscriptions.cs index c947bc5fb..b8d4639d4 100644 --- a/Robust.Shared/GameObjects/EntitySystem.Subscriptions.cs +++ b/Robust.Shared/GameObjects/EntitySystem.Subscriptions.cs @@ -7,16 +7,22 @@ namespace Robust.Shared.GameObjects { private List? _subscriptions; - protected void SubscribeNetworkEvent(EntityEventHandler handler) + // NOTE: EntityEventHandler and EntitySessionEventHandler CANNOT BE ORDERED BETWEEN EACH OTHER. + + protected void SubscribeNetworkEvent( + EntityEventHandler handler, + Type[]? before=null, Type[]? after=null) where T : notnull { - EntityManager.EventBus.SubscribeEvent(EventSource.Network, this, handler); + EntityManager.EventBus.SubscribeEvent(EventSource.Network, this, handler, GetType(), before, after); _subscriptions ??= new(); _subscriptions.Add(new SubBroadcast(EventSource.Network)); } - protected void SubscribeNetworkEvent(EntitySessionEventHandler handler) + protected void SubscribeNetworkEvent( + EntitySessionEventHandler handler, + Type[]? before=null, Type[]? after=null) where T : notnull { EntityManager.EventBus.SubscribeSessionEvent(EventSource.Network, this, handler); @@ -25,16 +31,20 @@ namespace Robust.Shared.GameObjects _subscriptions.Add(new SubBroadcast>(EventSource.Network)); } - protected void SubscribeLocalEvent(EntityEventHandler handler) + protected void SubscribeLocalEvent( + EntityEventHandler handler, + Type[]? before=null, Type[]? after=null) where T : notnull { - EntityManager.EventBus.SubscribeEvent(EventSource.Local, this, handler); + EntityManager.EventBus.SubscribeEvent(EventSource.Local, this, handler, GetType(), before, after); _subscriptions ??= new(); _subscriptions.Add(new SubBroadcast(EventSource.Local)); } - protected void SubscribeLocalEvent(EntitySessionEventHandler handler) + protected void SubscribeLocalEvent( + EntitySessionEventHandler handler, + Type[]? before=null, Type[]? after=null) where T : notnull { EntityManager.EventBus.SubscribeSessionEvent(EventSource.Local, this, handler); @@ -57,8 +67,9 @@ namespace Robust.Shared.GameObjects EntityManager.EventBus.UnsubscribeEvent(EventSource.Local, this); } - - protected void SubscribeLocalEvent(ComponentEventHandler handler) + protected void SubscribeLocalEvent( + ComponentEventHandler handler, + Type[]? before=null, Type[]? after=null) where TComp : IComponent where TEvent : EntityEventArgs { diff --git a/Robust.Shared/GameObjects/EntitySystemManager.cs b/Robust.Shared/GameObjects/EntitySystemManager.cs index c6f075cd6..f6d3ddc54 100644 --- a/Robust.Shared/GameObjects/EntitySystemManager.cs +++ b/Robust.Shared/GameObjects/EntitySystemManager.cs @@ -166,15 +166,15 @@ namespace Robust.Shared.GameObjects Dictionary.ValueCollection systems, Dictionary supertypeSystems) { - var allNodes = new List>(); - var typeToNode = new Dictionary>(); + var allNodes = new List>(); + var typeToNode = new Dictionary>(); foreach (var system in systems) { - var node = new GraphNode(system); + var node = new TopologicalSort.GraphNode(system); - allNodes.Add(node); typeToNode.Add(system.GetType(), node); + allNodes.Add(node); } foreach (var (type, system) in supertypeSystems) @@ -183,52 +183,30 @@ namespace Robust.Shared.GameObjects typeToNode[type] = node; } - foreach (var node in allNodes) + foreach (var node in typeToNode.Values) { - foreach (var after in node.System.UpdatesAfter) + foreach (var after in node.Value.UpdatesAfter) { var system = typeToNode[after]; - node.DependsOn.Add(system); + system.Dependant.Add(node); } - foreach (var before in node.System.UpdatesBefore) + foreach (var before in node.Value.UpdatesBefore) { var system = typeToNode[before]; - system.DependsOn.Add(node); + node.Dependant.Add(system); } } - var order = TopologicalSort(allNodes).Select(p => p.System).ToArray(); + var order = TopologicalSort.Sort(allNodes).ToArray(); var frameUpdate = order.Where(p => NeedsFrameUpdate(p.GetType())); var update = order.Where(p => NeedsUpdate(p.GetType())); return (frameUpdate, update); } - internal static IEnumerable> TopologicalSort(IEnumerable> nodes) - { - var elems = nodes.ToDictionary(node => node, - node => new HashSet>(node.DependsOn)); - while (elems.Count > 0) - { - var elem = - elems.FirstOrDefault(x => x.Value.Count == 0); - if (elem.Key == null) - { - throw new InvalidOperationException( - "Found circular dependency when resolving entity system update dependency graph"); - } - elems.Remove(elem.Key); - foreach (var selem in elems) - { - selem.Value.Remove(elem.Key); - } - yield return elem.Key; - } - } - private static IEnumerable GetBaseTypes(Type type) { if(type.BaseType == null) return type.GetInterfaces(); @@ -344,18 +322,6 @@ namespace Robust.Shared.GameObjects return mFrameUpdate!.DeclaringType != typeof(EntitySystem); } - [DebuggerDisplay("GraphNode: {" + nameof(System) + "}")] - internal sealed class GraphNode - { - public readonly T System; - public readonly List> DependsOn = new(); - - public GraphNode(T system) - { - System = system; - } - } - private struct UpdateReg { [ViewVariables] public IEntitySystem System; diff --git a/Robust.Shared/GameObjects/Systems/SharedPhysicsSystem.cs b/Robust.Shared/GameObjects/Systems/SharedPhysicsSystem.cs index bc7d68e75..3dc209ddf 100644 --- a/Robust.Shared/GameObjects/Systems/SharedPhysicsSystem.cs +++ b/Robust.Shared/GameObjects/Systems/SharedPhysicsSystem.cs @@ -14,6 +14,7 @@ using Robust.Shared.Physics.Controllers; using Robust.Shared.Physics.Dynamics; using Robust.Shared.Reflection; using Robust.Shared.Timing; +using Robust.Shared.Utility; using DependencyAttribute = Robust.Shared.IoC.DependencyAttribute; using Logger = Robust.Shared.Log.Logger; @@ -114,45 +115,24 @@ namespace Robust.Shared.GameObjects { var reflectionManager = IoCManager.Resolve(); var typeFactory = IoCManager.Resolve(); - var allControllerTypes = new List(); + var instantiated = new List(); foreach (var type in reflectionManager.GetAllChildren(typeof(VirtualController))) { - if (type.IsAbstract) continue; - allControllerTypes.Add(type); + if (type.IsAbstract) + continue; + + instantiated.Add(typeFactory.CreateInstance(type)); } - var instantiated = new Dictionary(); + var nodes = TopologicalSort.FromBeforeAfter( + instantiated, + c => c.GetType(), + c => c, + c => c.UpdatesBefore, + c => c.UpdatesAfter); - foreach (var type in allControllerTypes) - { - instantiated.Add(type, (VirtualController) typeFactory.CreateInstance(type)); - } - - // Build dependency graph, copied from EntitySystemManager *COUGH - - var nodes = new Dictionary>(); - - foreach (var (_, controller) in instantiated) - { - var node = new EntitySystemManager.GraphNode(controller); - nodes[controller.GetType()] = node; - } - - foreach (var (type, node) in nodes) - { - foreach (var before in instantiated[type].UpdatesBefore) - { - nodes[before].DependsOn.Add(node); - } - - foreach (var after in instantiated[type].UpdatesAfter) - { - node.DependsOn.Add(nodes[after]); - } - } - - _controllers = GameObjects.EntitySystemManager.TopologicalSort(nodes.Values).Select(c => c.System).ToList(); + _controllers = TopologicalSort.Sort(nodes).ToList(); foreach (var controller in _controllers) { diff --git a/Robust.Shared/Input/Binding/CommandBindRegistry.cs b/Robust.Shared/Input/Binding/CommandBindRegistry.cs index c6cb01800..8ebe4bd41 100644 --- a/Robust.Shared/Input/Binding/CommandBindRegistry.cs +++ b/Robust.Shared/Input/Binding/CommandBindRegistry.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; using System.Linq; using Robust.Shared.Log; +using Robust.Shared.Utility; namespace Robust.Shared.Input.Binding { @@ -13,8 +14,8 @@ namespace Robust.Shared.Input.Binding // handlers in the order they should be resolved for the given key function. // internally we use a graph to construct this but we render it down to a flattened // list so we don't need to do any graph traversal at query time - private Dictionary> _bindingsForKey = - new(); + private Dictionary> _bindingsForKey = new(); + private bool _graphDirty = false; /// public void Register(CommandBinds commandBinds) @@ -41,12 +42,15 @@ namespace Robust.Shared.Input.Binding _bindings.Add(new TypedCommandBind(owner, binding)); } - RebuildGraph(); + _graphDirty = true; } /// public IEnumerable GetHandlers(BoundKeyFunction function) { + if (_graphDirty) + RebuildGraph(); + if (_bindingsForKey.TryGetValue(function, out var handlers)) { return handlers; @@ -58,7 +62,8 @@ namespace Robust.Shared.Input.Binding public void Unregister(Type owner) { _bindings.RemoveAll(binding => binding.ForType == owner); - RebuildGraph(); + + _graphDirty = true; } /// @@ -67,7 +72,7 @@ namespace Robust.Shared.Input.Binding Unregister(typeof(TOwner)); } - private void RebuildGraph() + internal void RebuildGraph() { _bindingsForKey.Clear(); @@ -77,6 +82,7 @@ namespace Robust.Shared.Input.Binding } + _graphDirty = false; } private Dictionary> FunctionToBindings() @@ -105,16 +111,16 @@ namespace Robust.Shared.Input.Binding //TODO: Probably could be optimized if needed! Generally shouldn't be a big issue since there is a relatively // tiny amount of bindings - List allNodes = new(); - Dictionary> typeToNode = new(); + List> allNodes = new(); + Dictionary>> typeToNode = new(); // build the dict for quick lookup on type foreach (var binding in bindingsForFunction) { if (!typeToNode.ContainsKey(binding.ForType)) { - typeToNode[binding.ForType] = new List(); + typeToNode[binding.ForType] = new List>(); } - var newNode = new GraphNode(binding); + var newNode = new TopologicalSort.GraphNode(binding); typeToNode[binding.ForType].Add(newNode); allNodes.Add(newNode); } @@ -122,7 +128,7 @@ namespace Robust.Shared.Input.Binding //add the graph edges foreach (var curBinding in allNodes) { - foreach (var afterType in curBinding.TypedCommandBind.CommandBind.After) + foreach (var afterType in curBinding.Value.CommandBind.After) { // curBinding should always fire after bindings associated with this afterType, i.e. // this binding DEPENDS ON afterTypes' bindings @@ -130,11 +136,12 @@ namespace Robust.Shared.Input.Binding { foreach (var afterBinding in afterBindings) { - curBinding.DependsOn.Add(afterBinding); + afterBinding.Dependant.Add(curBinding); } } } - foreach (var beforeType in curBinding.TypedCommandBind.CommandBind.Before) + + foreach (var beforeType in curBinding.Value.CommandBind.Before) { // curBinding should always fire before bindings associated with this beforeType, i.e. // beforeTypes' bindings DEPENDS ON this binding @@ -142,7 +149,7 @@ namespace Robust.Shared.Input.Binding { foreach (var beforeBinding in beforeBindings) { - beforeBinding.DependsOn.Add(curBinding); + curBinding.Dependant.Add(beforeBinding); } } } @@ -151,54 +158,7 @@ namespace Robust.Shared.Input.Binding //TODO: Log graph structure for debugging //use toposort to build the final result - var topoSorted = TopologicalSort(allNodes, function); - List result = new(); - - foreach (var node in topoSorted) - { - result.Add(node.TypedCommandBind.CommandBind.Handler); - } - - return result; - } - - //Adapted from https://stackoverflow.com/a/24058279 - private static IEnumerable TopologicalSort(IEnumerable nodes, BoundKeyFunction function) - { - var elems = nodes.ToDictionary(node => node, - node => new HashSet(node.DependsOn)); - while (elems.Count > 0) - { - var elem = - elems.FirstOrDefault(x => x.Value.Count == 0); - if (elem.Key == null) - { - throw new InvalidOperationException("Found circular dependency when resolving" + - $" command binding handler order for key function {function.FunctionName}." + - $" Please check the systems which register bindings for" + - $" this function and eliminate the circular dependency."); - } - elems.Remove(elem.Key); - foreach (var selem in elems) - { - selem.Value.Remove(elem.Key); - } - yield return elem.Key; - } - } - - /// - /// node in our temporary dependency graph - /// - private class GraphNode - { - public List DependsOn = new(); - public readonly TypedCommandBind TypedCommandBind; - - public GraphNode(TypedCommandBind typedCommandBind) - { - TypedCommandBind = typedCommandBind; - } + return TopologicalSort.Sort(allNodes).Select(c => c.CommandBind.Handler).ToList(); } /// diff --git a/Robust.Shared/Utility/TopologicalSort.cs b/Robust.Shared/Utility/TopologicalSort.cs new file mode 100644 index 000000000..33ede15bc --- /dev/null +++ b/Robust.Shared/Utility/TopologicalSort.cs @@ -0,0 +1,125 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; + +namespace Robust.Shared.Utility +{ + public sealed class TopologicalSort + { + public static IEnumerable Sort(IEnumerable> nodes) + { + var totalVerts = 0; + var empty = new Queue>(); + + var nodesArray = nodes.ToArray(); + + foreach (var node in nodesArray) + { + totalVerts += 1; + + foreach (var dep in node.Dependant) + { + dep.DependsOnCount += 1; + } + } + + foreach (var node in nodesArray) + { + if (node.DependsOnCount == 0) + empty.Enqueue(node); + } + + while (empty.TryDequeue(out var node)) + { + yield return node.Value; + totalVerts -= 1; + + foreach (var dep in node.Dependant) + { + dep.DependsOnCount -= 1; + if (dep.DependsOnCount == 0) + empty.Enqueue(dep); + } + } + + if (totalVerts != 0) + throw new InvalidOperationException("Graph contained cycle(s)."); + } + + // I will never stop using the word "datum". + public static IEnumerable> FromBeforeAfter( + IEnumerable data, + Func keySelector, + Func> beforeSelector, + Func> afterSelector, + bool allowMissing = false) + where TValue : notnull + { + return FromBeforeAfter(data, keySelector, keySelector, afterSelector, beforeSelector, allowMissing); + } + + public static IEnumerable> FromBeforeAfter( + IEnumerable data, + Func keySelector, + Func valueSelector, + Func> beforeSelector, + Func> afterSelector, + bool allowMissing=false) + where TKey : notnull + { + var dict = new Dictionary node)>(); + + foreach (var datum in data) + { + var key = keySelector(datum); + var value = valueSelector(datum); + dict.Add(key, (datum, new GraphNode(value))); + } + + foreach (var (key, (datum, node)) in dict) + { + foreach (var before in beforeSelector(datum)) + { + if (dict.TryGetValue(before, out var entry)) + { + node.Dependant.Add(entry.node); + } + else if (!allowMissing) + { + throw new InvalidOperationException($"Vertex '{before}' referenced by '{key}' was not found in the graph."); + } + } + + foreach (var after in afterSelector(datum)) + { + if (dict.TryGetValue(after, out var entry)) + { + entry.node.Dependant.Add(node); + } + else if (!allowMissing) + { + throw new InvalidOperationException($"Vertex '{after}' referenced by '{key}' was not found in the graph."); + } + } + } + + return dict.Values.Select(c => c.node); + } + + [DebuggerDisplay("GraphNode: {" + nameof(Value) + "}")] + public class GraphNode + { + public readonly T Value; + public readonly List> Dependant = new(); + + // Used internal by sort implementation, do not touch. + internal int DependsOnCount; + + public GraphNode(T value) + { + Value = value; + } + } + } +} diff --git a/Robust.UnitTesting/Shared/GameObjects/EntityEventBusTests.ComponentEvent.cs b/Robust.UnitTesting/Shared/GameObjects/EntityEventBusTests.ComponentEvent.cs index c0e3116df..d4f5b1497 100644 --- a/Robust.UnitTesting/Shared/GameObjects/EntityEventBusTests.ComponentEvent.cs +++ b/Robust.UnitTesting/Shared/GameObjects/EntityEventBusTests.ComponentEvent.cs @@ -166,11 +166,98 @@ namespace Robust.UnitTesting.Shared.GameObjects } } + [Test] + public void CompEventOrdered() + { + // Arrange + var entUid = new EntityUid(7); + + var entManMock = new Mock(); + + var compManMock = new Mock(); + var compFacMock = new Mock(); + + void Setup(out T instance) where T : IComponent, new() + { + IComponent? inst = instance = new T(); + var reg = new Mock(); + reg.Setup(m => m.References).Returns(new Type[] {typeof(T)}); + + compFacMock.Setup(m => m.GetRegistration(typeof(T))).Returns(reg.Object); + compManMock.Setup(m => m.TryGetComponent(entUid, typeof(T), out inst)).Returns(true); + compManMock.Setup(m => m.GetComponent(entUid, typeof(T))).Returns(inst); + } + + Setup(out var instA); + Setup(out var instB); + Setup(out var instC); + + compManMock.Setup(m => m.ComponentFactory).Returns(compFacMock.Object); + entManMock.Setup(m => m.ComponentManager).Returns(compManMock.Object); + var bus = new EntityEventBus(entManMock.Object); + + // Subscribe + var a = false; + var b = false; + var c = false; + + void HandlerA(EntityUid uid, Component comp, TestEvent ev) + { + Assert.That(b, Is.False, "A should run before B"); + Assert.That(c, Is.False, "A should run before C"); + + a = true; + } + + void HandlerB(EntityUid uid, Component comp, TestEvent ev) + { + Assert.That(c, Is.True, "B should run after C"); + b = true; + } + + void HandlerC(EntityUid uid, Component comp, TestEvent ev) => c = true; + + bus.SubscribeLocalEvent(HandlerA, typeof(OrderComponentA), before: new []{typeof(OrderComponentB), typeof(OrderComponentC)}); + bus.SubscribeLocalEvent(HandlerB, typeof(OrderComponentB), after: new []{typeof(OrderComponentC)}); + bus.SubscribeLocalEvent(HandlerC, typeof(OrderComponentC)); + + // add a component to the system + entManMock.Raise(m=>m.EntityAdded += null, entManMock.Object, entUid); + compManMock.Raise(m => m.ComponentAdded += null, new AddedComponentEventArgs(instA, entUid)); + compManMock.Raise(m => m.ComponentAdded += null, new AddedComponentEventArgs(instB, entUid)); + compManMock.Raise(m => m.ComponentAdded += null, new AddedComponentEventArgs(instC, entUid)); + + // Raise + var evntArgs = new TestEvent(5); + bus.RaiseLocalEvent(entUid, evntArgs); + + // Assert + Assert.That(a, Is.True, "A did not fire"); + Assert.That(b, Is.True, "B did not fire"); + Assert.That(c, Is.True, "C did not fire"); + } + private class DummyComponent : Component { public override string Name => "Dummy"; } + private class OrderComponentA : Component + { + public override string Name => "OrderComponentA"; + } + + private class OrderComponentB : Component + { + public override string Name => "OrderComponentB"; + } + + private class OrderComponentC : Component + { + public override string Name => "OrderComponentC"; + } + + private class TestEvent : EntityEventArgs { public int TestNumber { get; } diff --git a/Robust.UnitTesting/Shared/GameObjects/EntityEventBusTests.SystemEvent.cs b/Robust.UnitTesting/Shared/GameObjects/EntityEventBusTests.SystemEvent.cs index 7db5333c2..2d72be9ae 100644 --- a/Robust.UnitTesting/Shared/GameObjects/EntityEventBusTests.SystemEvent.cs +++ b/Robust.UnitTesting/Shared/GameObjects/EntityEventBusTests.SystemEvent.cs @@ -469,6 +469,58 @@ namespace Robust.UnitTesting.Shared.GameObjects // Assert Assert.Throws(Code); } + + [Test] + public void RaiseEvent_Ordered() + { + // Arrange + var bus = BusFactory(); + + // Expected order is A -> C -> B + var a = false; + var b = false; + var c = false; + + void HandlerA(TestEventArgs ev) + { + Assert.That(b, Is.False, "A should run before B"); + Assert.That(c, Is.False, "A should run before C"); + + a = true; + } + + void HandlerB(TestEventArgs ev) + { + Assert.That(c, Is.True, "B should run after C"); + b = true; + } + + void HandlerC(TestEventArgs ev) => c = true; + + bus.SubscribeEvent(EventSource.Local, new SubA(), HandlerA, typeof(SubA), before: new []{typeof(SubB), typeof(SubC)}); + bus.SubscribeEvent(EventSource.Local, new SubB(), HandlerB, typeof(SubB), after: new []{typeof(SubC)}); + bus.SubscribeEvent(EventSource.Local, new SubC(), HandlerC, typeof(SubC)); + + // Act + bus.RaiseEvent(EventSource.Local, new TestEventArgs()); + + // Assert + Assert.That(a, Is.True, "A did not fire"); + Assert.That(b, Is.True, "B did not fire"); + Assert.That(c, Is.True, "C did not fire"); + } + + public sealed class SubA : IEntityEventSubscriber + { + } + + public sealed class SubB : IEntityEventSubscriber + { + } + + public sealed class SubC : IEntityEventSubscriber + { + } } internal class TestEventSubscriber : IEntityEventSubscriber { } diff --git a/Robust.UnitTesting/Shared/Input/Binding/CommandBindRegistry_Test.cs b/Robust.UnitTesting/Shared/Input/Binding/CommandBindRegistry_Test.cs index 7af861392..a28dd6798 100644 --- a/Robust.UnitTesting/Shared/Input/Binding/CommandBindRegistry_Test.cs +++ b/Robust.UnitTesting/Shared/Input/Binding/CommandBindRegistry_Test.cs @@ -222,12 +222,12 @@ namespace Robust.UnitTesting.Shared.Input.Binding .Bind(bkf, bHandler1) .Bind(bkf, bHandler2) .Register(registry); + CommandBinds.Builder + .Bind(bkf, cHandler1) + .BindAfter(bkf, cHandler2, typeof(TypeA)) + .Register(registry); - Assert.Throws(() => - CommandBinds.Builder - .Bind(bkf, cHandler1) - .BindAfter(bkf, cHandler2, typeof(TypeA)) - .Register(registry)); + Assert.Throws(registry.RebuildGraph); } } }