Files
RobustToolbox/Robust.Shared/IoC/DependencyCollection.cs
T
Pieter-Jan Briers b4eb85ad3c [Dependency] source generator (#6549)
* [Dependency] source generator

No more reflection, no more codegen at runtime

Also various changes to Roslyn helpers to make this easier to write.

Requires all types with dependencies to be partial and not have readonly dependency fields. An analyzer enforces this at warning level, the previous injection strategies have remained in the code *for now* as a fallback.

No fallback is available for [field: Dependency] properties, due to a Roslyn bug.

Code fixes exist. We love Roslyn.

* Release notes

* Handle nullable dependencies

These are bad but gotta deal with it.

* Apply suggestions from code review

Co-authored-by: Moony <moony@hellomouse.net>

* Fine, let's not use collection expressions

---------

Co-authored-by: Moony <moony@hellomouse.net>
2026-05-08 12:38:02 +02:00

740 lines
27 KiB
C#

using System;
using System.Collections.Concurrent;
using System.Collections.Frozen;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Reflection;
using System.Reflection.Emit;
using System.Runtime.CompilerServices;
using System.Threading;
using JetBrains.Annotations;
using Robust.Shared.IoC.Exceptions;
using Robust.Shared.Utility;
using NotNull = System.Diagnostics.CodeAnalysis.NotNullAttribute;
namespace Robust.Shared.IoC
{
/// <summary>
/// Parameterless factory supplied by callers to produce the instance of a registered service.
/// The dependency collection invokes it when the object graph is built.
/// </summary>
/// <typeparam name="T">The service instance type produced by the factory.</typeparam>
public delegate T DependencyFactoryDelegate<out T>()
where T : class;
/// <summary>
/// Factory registered for an open generic interface (via <c>RegisterBaseGenericLazy</c>).
/// It is called on demand the first time a closed construction of that interface is resolved.
/// </summary>
/// <typeparam name="T">The service instance type produced by the factory.</typeparam>
/// <param name="type">The closed generic type that is being resolved.</param>
/// <param name="services">The dependency collection performing the resolution.</param>
public delegate T DependencyFactoryBaseGenericLazyDelegate<out T>(
Type type,
IDependencyCollection services)
where T : class;
/// <inheritdoc />
internal sealed class DependencyCollection : IDependencyCollection
{
// Shape of the DynamicMethod injectors emitted by CacheInjector: (object to inject into, resolved services).
private delegate void InjectorDelegate(object target, object[] services);

// Internal factory shape. It receives the services constructed so far, so a factory running during
// BuildGraph can look up its own constructor dependencies.
private delegate T DependencyFactoryDelegateInternal<out T>(IReadOnlyDictionary<Type, object> services)
    where T : class;

// Parameter list shared by every injector DynamicMethod; mirrors InjectorDelegate.
private static readonly Type[] InjectorParameters = new[] { typeof(object), typeof(object[]) };

// Temporary: per-type cache recording whether source-generated IHasDependencies injection may be used.
// The value is false for any type with a hierarchy level that declares [Dependency] fields the source
// generator has not processed. Access is synchronized by locking on the dictionary itself.
private static readonly Dictionary<Type, bool> IsHasDependenciesSafe = new Dictionary<Type, bool>();

/// <summary>
/// Maps the types passed to <see cref="Resolve{T}()"/> to their constructed implementation.
/// This table is consulted first on every resolve.
/// </summary>
/// <remarks>
/// The table is never mutated in place. It is replaced wholesale so readers always observe a consistent snapshot.
/// </remarks>
private FrozenDictionary<Type, object> _services = FrozenDictionary<Type, object>.Empty;

/// <summary>
/// Instances created on demand for closed generic types whose open definition was registered
/// through <see cref="RegisterBaseGenericLazy"/>.
/// </summary>
private readonly ConcurrentDictionary<Type, object> _lazyServices = new ConcurrentDictionary<Type, object>();

// ---- State used to build new services (mutated under _serviceBuildLock). ----

/// <summary>
/// Registered interface type mapped to its implementation type.
/// Consulted when a service that does not exist yet has to be constructed.
/// </summary>
private readonly Dictionary<Type, Type> _resolveTypes = new Dictionary<Type, Type>();

// Implementation type mapped to the factory that constructs it.
private readonly Dictionary<Type, DependencyFactoryDelegateInternal<object>> _resolveFactories =
    new Dictionary<Type, DependencyFactoryDelegateInternal<object>>();

// Interface types registered but not yet constructed, in registration order.
private readonly Queue<Type> _pendingResolves = new Queue<Type>();

// Open generic type definition mapped to the factory used for its closed constructions.
private readonly ConcurrentDictionary<Type, DependencyFactoryBaseGenericLazyDelegate<object>> _baseGenericLazyFactories =
    new ConcurrentDictionary<Type, DependencyFactoryBaseGenericLazyDelegate<object>>();

private readonly object _serviceBuildLock = new object();

// ---- End of build state. ----

// Per-type injection plans. For types without source-generated injection, a DynamicMethod does the actual
// field stores. This is far cheaper than reflection and does not allocate once set up.
private readonly Dictionary<Type, CachedInjector> _injectorCache = new Dictionary<Type, CachedInjector>();
private readonly ReaderWriterLockSlim _injectorCacheLock = new ReaderWriterLockSlim();

// Collection consulted for types this one cannot resolve; null for a root collection.
private readonly IDependencyCollection? _parentCollection;
/// <summary>
/// Creates a root collection with no parent to fall back to.
/// </summary>
public DependencyCollection()
{
}

/// <summary>
/// Creates a child collection. Types it cannot resolve itself are looked up in
/// <paramref name="parentCollection"/>.
/// </summary>
public DependencyCollection(IDependencyCollection parentCollection) => _parentCollection = parentCollection;
/// <summary>
/// Creates a new, empty child collection whose failed lookups fall through to
/// <paramref name="parentCollection"/>.
/// </summary>
public IDependencyCollection FromParent(IDependencyCollection parentCollection) =>
    new DependencyCollection(parentCollection);
/// <inheritdoc />
public IEnumerable<Type> GetRegisteredTypes()
{
    // Constructed services first, then lazily created generic services, then whatever the parent knows.
    var ownTypes = _services.Keys.Concat(_lazyServices.Keys);
    if (_parentCollection == null)
        return ownTypes;

    return ownTypes.Concat(_parentCollection.GetRegisteredTypes());
}
/// <summary>
/// Returns the types for which a DynamicMethod injector has been generated and cached.
/// Cache entries without a delegate are omitted.
/// </summary>
public Type[] GetCachedInjectorTypes()
{
    using (_injectorCacheLock.ReadGuard())
    {
        var withDelegate = new List<Type>();
        foreach (var (cachedType, injector) in _injectorCache)
        {
            if (injector.Delegate != null)
                withDelegate.Add(cachedType);
        }

        return withDelegate.ToArray();
    }
}
/// <inheritdoc />
public bool TryResolveType<T>([NotNullWhen(true)] out T? instance)
{
    // A resolved object of the wrong runtime type is treated the same as "not found".
    if (TryResolveType(typeof(T), out object? rawInstance) && rawInstance is T typedInstance)
    {
        instance = typedInstance;
        return true;
    }

    instance = default;
    return false;
}
/// <inheritdoc />
public bool TryResolveType(Type objectType, [MaybeNullWhen(false)] out object instance) =>
    TryResolveType(objectType, _services, out instance);
// Overload for the frozen live service table; it forwards to the read-only dictionary overload.
private bool TryResolveType(
    Type objectType,
    FrozenDictionary<Type, object> services,
    [MaybeNullWhen(false)] out object instance)
{
    IReadOnlyDictionary<Type, object> readOnlyServices = services;
    return TryResolveType(objectType, readOnlyServices, out instance);
}
// Resolves objectType from, in order:
//   1. the given service table (the live services, or the partially built graph inside BuildGraph);
//   2. a lazily created instance, when objectType is a closed generic whose open definition has a
//      factory registered through RegisterBaseGenericLazy (created once and kept in _lazyServices);
//   3. the parent collection, if there is one.
private bool TryResolveType(
    Type objectType,
    IReadOnlyDictionary<Type, object> services,
    [MaybeNullWhen(false)] out object instance)
{
    if (services.TryGetValue(objectType, out instance))
        return true;

    if (objectType.IsGenericType
        && _baseGenericLazyFactories.TryGetValue(objectType.GetGenericTypeDefinition(), out var lazyFactory))
    {
        instance = _lazyServices.GetOrAdd(objectType, closedType => lazyFactory(closedType, this));
        return true;
    }

    if (_parentCollection is null)
        return false;

    return _parentCollection.TryResolveType(objectType, out instance);
}
/// <inheritdoc />
/// <remarks>
/// The implementation is constructed during <see cref="BuildGraph"/> by the shared reflection-based default
/// factory. <typeparamref name="TImplementation"/> must declare exactly one constructor, and every constructor
/// parameter must resolve to a service that already exists or was constructed earlier in the same graph build.
/// Otherwise an <see cref="InvalidOperationException"/> is thrown.
/// </remarks>
public void Register<TInterface, TImplementation>(bool overwrite = false)
    where TImplementation : class, TInterface
    where TInterface : class
{
    // With no factory, the non-generic path falls back to its default factory. That factory performs exactly
    // the constructor resolution (same messages, same exceptions) this method used to duplicate inline.
    // The typed null picks the public overload; a bare null would be ambiguous with the private
    // internal-delegate overload.
    Register(typeof(TInterface), typeof(TImplementation), (DependencyFactoryDelegate<object>?) null, overwrite);
}
/// <inheritdoc />
public void Register<TInterface, TImplementation>(
    DependencyFactoryDelegate<TImplementation> factory,
    bool overwrite = false)
    where TImplementation : class, TInterface
    where TInterface : class
{
    var interfaceType = typeof(TInterface);
    var implementationType = typeof(TImplementation);

    // The factory converts covariantly to DependencyFactoryDelegate<object>.
    Register(interfaceType, implementationType, factory, overwrite);
}
// Generic convenience wrapper over the internal-factory registration path.
private void Register<TInterface, TImplementation>(
    DependencyFactoryDelegateInternal<TImplementation> factory,
    bool overwrite = false)
    where TImplementation : class, TInterface
    where TInterface : class
{
    // Covariant view of the same delegate instance.
    DependencyFactoryDelegateInternal<object> erasedFactory = factory;
    Register(typeof(TInterface), typeof(TImplementation), erasedFactory, overwrite);
}
/// <inheritdoc />
public void Register(
    Type implementation,
    DependencyFactoryDelegate<object>? factory = null,
    bool overwrite = false)
    => Register(implementation, implementation, factory, overwrite); // The type is registered as its own interface.
/// <summary>
/// Registers <paramref name="implementation"/> as the provider of <paramref name="interfaceType"/>.
/// The instance is created by <paramref name="factory"/>, or by reflection through its single constructor
/// when no factory is given. Construction happens on the next graph build.
/// </summary>
public void Register(
    Type interfaceType,
    Type implementation,
    DependencyFactoryDelegate<object>? factory = null,
    bool overwrite = false)
{
    // Adapt the public parameterless factory to the internal shape; null stays null (default factory).
    var internalFactory = FactoryToInternal(factory);
    Register(interfaceType, implementation, internalFactory, overwrite);
}
// Records interfaceType -> implementation, to be constructed on the next BuildGraph().
// When factory is null, the implementation is created by reflection through its single constructor.
// Each constructor parameter must resolve to a service constructed earlier in the graph.
// Throws InvalidOperationException (from CheckRegisterInterface) in two cases: the interface is already
// registered and overwrite is false, or the interface has already been instantiated.
private void Register(
    Type interfaceType,
    Type implementation,
    DependencyFactoryDelegateInternal<object>? factory = null,
    bool overwrite = false)
{
    // Reflection-based fallback factory. It runs inside BuildGraph with the services built so far.
    object DefaultFactory(IReadOnlyDictionary<Type, object> services)
    {
        var constructors =
            implementation.GetConstructors(BindingFlags.Public | BindingFlags.NonPublic |
                                           BindingFlags.Instance);
        if (constructors.Length != 1)
        {
            throw new InvalidOperationException(
                $"Dependency '{implementation.FullName}' requires exactly one constructor.");
        }

        var chosenConstructor = constructors[0];
        var constructorParams = chosenConstructor.GetParameters();
        var parameters = new object[constructorParams.Length];
        for (var index = 0; index < constructorParams.Length; index++)
        {
            var param = constructorParams[index];
            if (TryResolveType(param.ParameterType, services, out var instance))
            {
                parameters[index] = instance;
                continue;
            }

            // Distinguish "registered, but queued after this implementation" from "never registered".
            if (_resolveTypes.ContainsKey(param.ParameterType))
            {
                throw new InvalidOperationException(
                    $"Dependency '{implementation.FullName}' ctor requires {param.ParameterType.FullName} registered before it.");
            }

            throw new InvalidOperationException(
                $"Dependency '{implementation.FullName}' ctor has unknown dependency {param.ParameterType.FullName}");
        }

        return chosenConstructor.Invoke(parameters);
    }

    lock (_serviceBuildLock)
    {
        // Validate and record under one lock acquisition. With separate critical sections, two concurrent
        // registrations of the same interface could both pass the "already registered" check.
        // C# lock (Monitor) is reentrant, so the nested lock taken by CheckRegisterInterface is fine.
        CheckRegisterInterface(interfaceType, implementation, overwrite);

        _resolveTypes[interfaceType] = implementation;
        // NOTE(review): keyed by implementation type, so registrations sharing an implementation type share
        // a single factory slot (the last registration wins).
        _resolveFactories[implementation] = factory ?? DefaultFactory;
        _pendingResolves.Enqueue(interfaceType);
    }
}
// Throws InvalidOperationException if registering interfaceType would clash with an existing registration.
// This always happens when overwrite is false. With overwrite, it still happens when the existing
// service has already been instantiated.
private void CheckRegisterInterface(Type interfaceType, Type implementationType, bool overwrite)
{
    lock (_serviceBuildLock)
    {
        if (!_resolveTypes.TryGetValue(interfaceType, out var previousImplementation))
            return;

        if (!overwrite)
        {
            throw new InvalidOperationException(
                $"Attempted to register already registered interface {interfaceType}. New implementation: {implementationType}, Old implementation: {previousImplementation}");
        }

        // Overwriting is only permitted while the old registration is still unbuilt.
        if (_services.ContainsKey(interfaceType))
        {
            throw new InvalidOperationException(
                $"Attempted to overwrite already instantiated interface {interfaceType}.");
        }
    }
}
/// <summary>
/// Registers an already constructed <paramref name="implementation"/> as the provider of
/// <typeparamref name="TInterface"/>.
/// </summary>
public void RegisterInstance<TInterface>(object implementation, bool overwrite = false)
    where TInterface : class
    => RegisterInstance(typeof(TInterface), implementation, overwrite);
/// <inheritdoc />
public void RegisterInstance(Type type, object implementation, bool overwrite = false)
{
    ArgumentNullException.ThrowIfNull(implementation);

    var implementationType = implementation.GetType();
    if (!implementationType.IsAssignableTo(type))
    {
        throw new InvalidOperationException(
            $"Implementation type {implementationType} is not assignable to type {type}");
    }

    // The "factory" simply hands back the given object when the graph is built.
    Register(type, implementationType, () => implementation, overwrite);
}
/// <summary>
/// Registers <paramref name="factory"/> for every closed construction of <paramref name="interfaceType"/>.
/// Instances are created on first resolve and then reused.
/// </summary>
/// <remarks>
/// Lookups use <see cref="Type.GetGenericTypeDefinition"/>, so <paramref name="interfaceType"/> is expected
/// to be an open generic definition such as <c>typeof(IFoo&lt;&gt;)</c>.
/// A later registration for the same definition replaces the earlier factory.
/// </remarks>
public void RegisterBaseGenericLazy(Type interfaceType, DependencyFactoryBaseGenericLazyDelegate<object> factory)
{
    lock (_serviceBuildLock)
    {
        _baseGenericLazyFactories[interfaceType] = factory;
    }
}
/// <inheritdoc />
/// <remarks>
/// Disposes each distinct constructed service, eager or lazily created, that implements
/// <see cref="IDisposable"/>. It then forgets all of the following:
/// instances, registrations, pending constructions, factories (including the base-generic lazy
/// factories) and cached injectors.
/// </remarks>
public void Clear()
{
    foreach (var service in _services.Values.Concat(_lazyServices.Values).OfType<IDisposable>().Distinct())
    {
        service.Dispose();
    }

    _services = FrozenDictionary<Type, object>.Empty;
    _lazyServices.Clear();

    lock (_serviceBuildLock)
    {
        _resolveTypes.Clear();
        _resolveFactories.Clear();
        _pendingResolves.Clear();
        // Previously these survived a Clear, so the "cleared" collection would keep creating fresh lazy
        // services for generic interfaces registered before it, right after disposing the old ones.
        _baseGenericLazyFactories.Clear();
    }

    using (_injectorCacheLock.WriteGuard())
    {
        _injectorCache.Clear();
    }
}
/// <inheritdoc />
[System.Diagnostics.Contracts.Pure]
public T Resolve<T>() => (T) ResolveType(typeof(T));

/// <summary>
/// Resolves <typeparamref name="T"/> into <paramref name="instance"/>, unless it already holds a value.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void Resolve<T>([NotNull] ref T? instance)
{
    if (instance is null)
    {
        // Resolve<T>() either throws or returns a concrete instance, hence the null-forgiving operator.
        instance = Resolve<T>()!;
    }
}

/// <summary>Fills each still-unset reference, in argument order.</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void Resolve<T1, T2>([NotNull] ref T1? instance1, [NotNull] ref T2? instance2)
{
    Resolve(ref instance1);
    Resolve(ref instance2);
}

/// <summary>Fills each still-unset reference, in argument order.</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void Resolve<T1, T2, T3>([NotNull] ref T1? instance1, [NotNull] ref T2? instance2,
    [NotNull] ref T3? instance3)
{
    Resolve(ref instance1);
    Resolve(ref instance2);
    Resolve(ref instance3);
}

/// <summary>Fills each still-unset reference, in argument order.</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void Resolve<T1, T2, T3, T4>([NotNull] ref T1? instance1, [NotNull] ref T2? instance2,
    [NotNull] ref T3? instance3, [NotNull] ref T4? instance4)
{
    Resolve(ref instance1);
    Resolve(ref instance2);
    Resolve(ref instance3);
    Resolve(ref instance4);
}
/// <inheritdoc />
[System.Diagnostics.Contracts.Pure]
public object ResolveType(Type type)
{
    if (TryResolveType(type, out var value))
        return value;

    // Registered but not constructed: BuildGraph() has not run since this type was registered.
    if (_resolveTypes.ContainsKey(type))
    {
        throw new InvalidOperationException(
            $"Attempted to resolve type {type} before the object graph for it has been populated.");
    }

    // The collection hands out itself unless IDependencyCollection was registered explicitly.
    return type == typeof(IDependencyCollection) ? this : throw new UnregisteredTypeException(type);
}
/// <inheritdoc />
/// <remarks>
/// Runs entirely under <c>_serviceBuildLock</c>, in three phases:
/// <list type="number">
/// <item>Pending registrations are resolved into a copy of the current services. Each one is either aliased to an
/// already-constructed instance of the same implementation type, or built with its registered factory.</item>
/// <item>The factories are dropped, and the copy is frozen and published as the new service table.</item>
/// <item>Dependencies are injected into each newly constructed instance. <see cref="IPostInjectInit.PostInject"/>
/// is then called on those that implement <see cref="IPostInjectInit"/>.</item>
/// </list>
/// A <see cref="TargetInvocationException"/> from a factory is rethrown as an
/// <see cref="ImplementationConstructorException"/>. In that case nothing built in this call is published.
/// NOTE(review): registrations already dequeued before the failure are not re-enqueued either, so a later
/// BuildGraph will not retry them.
/// NOTE(review): factories are keyed by implementation type. Two registrations sharing an implementation type
/// (e.g. two RegisterInstance calls with objects of the same class) therefore end up sharing one instance.
/// </remarks>
public void BuildGraph()
{
lock (_serviceBuildLock)
{
// List of all objects we need to inject dependencies into.
var injectList = new List<object>();
// Mutable working copy; already-published services stay visible to factories via this dictionary.
var newDeps = _services.ToDictionary();
// First we build every type we have registered but isn't yet built.
// This allows us to run this after the content assembly has been loaded.
while (_pendingResolves.Count > 0)
{
Type key = _pendingResolves.Dequeue();
var value = _resolveTypes[key];
// Find a potential dupe by checking other registered types that have already been instantiated that have the same instance type.
// Can't catch ourselves because we're not instantiated.
// Ones that aren't yet instantiated are about to be and will find us instead.
// This is a linear scan over all registrations for every pending entry.
var (type, _) =
_resolveTypes.FirstOrDefault(p => newDeps.ContainsKey(p.Key) && p.Value == value)!;
// Interface key can't be null so since KeyValuePair<> is a struct,
// this effectively checks whether we found something.
if (type != null)
{
// We have something with the same instance type, use that.
newDeps[key] = newDeps[type];
continue;
}
try
{
// Yay for delegate covariance
// The factory sees newDeps, i.e. everything constructed so far in this build.
object instance = _resolveFactories[value].Invoke(newDeps);
newDeps[key] = instance;
injectList.Add(instance);
}
catch (TargetInvocationException e)
{
// Unwrap reflection's wrapper so callers see which implementation's constructor failed.
throw new ImplementationConstructorException(value, e.InnerException);
}
}
// Because we only ever construct an instance once per registration, there is no need to keep the factory
// delegates. Also we need to free the delegates because lambdas capture variables.
_resolveFactories.Clear();
// Atomically set the new dict of services.
_services = newDeps.ToFrozenDictionary();
// Graph built, go over ones that need injection.
// Uses the uncached one-off injection path, resolving against the just-published _services.
foreach (var implementation in injectList)
{
InjectDependenciesReflection(implementation);
}
// Only after every new service has its dependencies is PostInject run.
foreach (var injectedItem in injectList.OfType<IPostInjectInit>())
{
injectedItem.PostInject();
}
}
}
/// <inheritdoc />
public void InjectDependencies(object obj, bool oneOff = false)
{
    var type = obj.GetType();

    CachedInjector injector;
    bool isCached;
    using (_injectorCacheLock.ReadGuard())
    {
        isCached = _injectorCache.TryGetValue(type, out injector);
    }

    if (!isCached)
    {
        // One-off injections take the uncached path so rarely used types don't bloat the cache.
        if (oneOff)
        {
            InjectDependenciesReflection(obj);
            return;
        }

        injector = CacheInjector(obj, type);
    }

    var services = injector.Services;
    if (services is { Length: 0 })
        return;

    if (injector.HasDependencies)
        ((IHasDependencies)obj).Inject(services);

    // A null delegate means there is no generated injector, so there are no fields to set.
    injector.Delegate?.Invoke(obj, services!);
}
// Resolves one dependency of owningType against the given service table.
// Failures surface as UnregisteredDependencyException rather than Resolve<T>()'s exception types.
private object ResolveForInjection(Type owningType, Type fieldType, FrozenDictionary<Type, object> services)
{
    if (TryResolveType(fieldType, services, out var dependency))
        return dependency;

    // Hard-coded so the collection can inject itself. The collection is not stored in the services,
    // which leaves room for an explicit registration to override it.
    if (fieldType == typeof(IDependencyCollection))
        return this;

    throw new UnregisteredDependencyException(owningType, fieldType);
}
// Injects obj once without caching anything. It uses source-generated IHasDependencies injection when the
// whole type hierarchy supports it; otherwise it sets every [Dependency] field through reflection.
private void InjectDependenciesReflection(object obj)
{
    var type = obj.GetType();

    if (CalculateHasDependenciesSafe(type))
    {
        InjectImmediateHasDependencies(obj);
        return;
    }

    var dependencyFields = type.GetAllFields()
        .Where(candidate => Attribute.IsDefined(candidate, typeof(DependencyAttribute)));

    foreach (var field in dependencyFields)
    {
        // NOTE: per the original authors, SetValue also works on readonly fields,
        // though that may be a CLR implementation detail.
        field.SetValue(obj, ResolveForInjection(type, field.FieldType, _services));
    }
}
// Resolves and injects the dependencies of a source-generated IHasDependencies object, without caching.
// Objects that do not implement the interface are left untouched.
private void InjectImmediateHasDependencies(object obj)
{
    if (obj is IHasDependencies hasDependencies)
    {
        var resolved = ResolveServicesArray(obj.GetType(), hasDependencies.GetDependencyTypes());
        hasDependencies.Inject(resolved);
    }
}
// Builds the injection plan for `type` under the write lock and stores it in _injectorCache.
// InjectDependencies calls this on a cache miss. Throws UnregisteredDependencyException if a
// dependency cannot be resolved; in that case nothing is cached.
private CachedInjector CacheInjector(object obj, Type type)
{
    using var _ = _injectorCacheLock.WriteGuard();

    // Check in case value got filled in right before we acquired the lock.
    if (_injectorCache.TryGetValue(type, out var cached))
        return cached;

    if (CalculateHasDependenciesSafe(type))
    {
        // Source-generated path. This result used to be returned without being stored, so every
        // InjectDependencies call for such a type re-took the write lock and re-resolved all services.
        // NOTE(review): caching assumes GetDependencyTypes() yields the same types for every instance of a
        // given runtime type (it is source-generated per type). Confirm if that ever changes.
        cached = CacheInjectorHasDependencies(obj, type);
        _injectorCache.Add(type, cached);
        return cached;
    }

    var fields = new List<FieldInfo>();
    foreach (var field in type.GetAllFields())
    {
        if (!Attribute.IsDefined(field, typeof(DependencyAttribute)))
        {
            continue;
        }

        fields.Add(field);
    }

    // No dependency fields, nothing to inject so no point setting this all up.
    if (fields.Count == 0)
    {
        _injectorCache.Add(type, default);
        return default;
    }

    // Owned by `type` with skipVisibility, so the emitted IL can store into private fields.
    var dynamicMethod = new DynamicMethod($"_injector<>{type}", null, InjectorParameters, type, true);
    dynamicMethod.DefineParameter(1, ParameterAttributes.In, "target");
    dynamicMethod.DefineParameter(2, ParameterAttributes.In, "services");

    var i = 0;
    var services = new List<object>();
    var generator = dynamicMethod.GetILGenerator();
    foreach (var field in fields)
    {
        // Load object to inject into.
        generator.Emit(OpCodes.Ldarg_0);

        // Not using Resolve<T>() because we're literally building it right now.
        if (!TryResolveType(field.FieldType, out var service))
        {
            // A hard-coded special case so the DependencyCollection can inject itself.
            // This is not put into the services so it can be overridden if needed.
            if (field.FieldType == typeof(IDependencyCollection))
            {
                service = this;
            }
            else
            {
                throw new UnregisteredDependencyException(type, field.FieldType, field.Name);
            }
        }

        services.Add(service);

        // Load services array.
        generator.Emit(OpCodes.Ldarg_1);
        // Load service from array; index i matches this service's position in `services`.
        generator.Emit(OpCodes.Ldc_I4, i++);
        generator.Emit(OpCodes.Ldelem_Ref);
        // Set service into field.
        generator.Emit(OpCodes.Stfld, field);
    }

    generator.Emit(OpCodes.Ret);

    var @delegate = (InjectorDelegate)dynamicMethod.CreateDelegate(typeof(InjectorDelegate));
    cached = new CachedInjector(@delegate, false, services.ToArray());
    _injectorCache.Add(type, cached);
    return cached;
}
// Builds the cache entry for a type handled by source-generated IHasDependencies injection.
// It does not add the entry to _injectorCache itself.
private CachedInjector CacheInjectorHasDependencies(object obj, Type type)
{
    DebugTools.Assert(type == obj.GetType());

    // Objects not implementing IHasDependencies get an empty service list, so nothing is injected.
    var resolved = obj is IHasDependencies hasDependencies
        ? ResolveServicesArray(type, hasDependencies.GetDependencyTypes())
        : Array.Empty<object>();

    return new CachedInjector(null, true, resolved);
}
// Resolves each requested dependency type, in order, against the live service table.
// owningType is only used for error reporting.
private object[] ResolveServicesArray(Type owningType, Type[] types)
{
    var resolved = new object[types.Length];
    var slot = 0;
    foreach (var dependencyType in types)
    {
        resolved[slot++] = ResolveForInjection(owningType, dependencyType, _services);
    }

    return resolved;
}
// Adapts a public parameterless factory to the internal shape, which also receives the in-progress
// service table; the adapted factory ignores that table. A null factory maps to null.
[return: NotNullIfNotNull("factory")]
private static DependencyFactoryDelegateInternal<T>? FactoryToInternal<T>(
    DependencyFactoryDelegate<T>? factory)
    where T : class
{
    if (factory is not null)
    {
        return ignoredServices => factory();
    }

    return null;
}
// Cached per-type injection plan consumed by InjectDependencies.
//   Delegate        - generated DynamicMethod that stores Services into the type's [Dependency] fields;
//                     null when there is nothing to store this way (no fields, or source-generated injection).
//   HasDependencies - true when injection goes through IHasDependencies.Inject instead of the delegate.
//   Services        - resolved dependencies, in the order the injector consumes them; null only for the
//                     default (no dependency fields) entry.
private record struct CachedInjector(InjectorDelegate? Delegate, bool HasDependencies, object[]? Services);
// True when `type` itself declares an instance field marked [Dependency]. Inherited fields are not
// considered; base types are checked separately.
private static bool HasAnyDependenciesAtLevel(Type type)
{
    const BindingFlags declaredInstanceFields =
        BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.DeclaredOnly;

    foreach (var field in type.GetFields(declaredInstanceFields))
    {
        if (field.HasCustomAttribute<DependencyAttribute>())
        {
            return true;
        }
    }

    return false;
}
// Decides (and caches process-wide) whether `type` can rely solely on source-generated IHasDependencies
// injection. That holds when every level of its hierarchy either carries [HasDependenciesGenerated] or
// declares no [Dependency] fields of its own.
private static bool CalculateHasDependenciesSafe(Type type)
{
    lock (IsHasDependenciesSafe)
    {
        if (IsHasDependenciesSafe.TryGetValue(type, out var known))
            return known;

        var safe = true;
        var level = type;
        while (level != null)
        {
            // A level the generator skipped that still has [Dependency] fields forces the fallback paths.
            if (!level.HasCustomAttribute<HasDependenciesGeneratedAttribute>() && HasAnyDependenciesAtLevel(level))
            {
                safe = false;
                break;
            }

            level = level.BaseType;
        }

        IsHasDependenciesSafe.Add(type, safe);
        return safe;
    }
}
}
}