Optimize assembly type checking.

It's now parallelized which cuts off ~200ms on its own for me.
Config is now shared between multiple loads which saves a lot as well.

All in all, pretty good.
This commit is contained in:
Pieter-Jan Briers
2020-12-14 16:34:33 +01:00
parent c335170fc1
commit 7473b6dae1
4 changed files with 95 additions and 69 deletions

View File

@@ -11,9 +11,8 @@ using System.Reflection.Metadata;
using System.Reflection.PortableExecutable;
using System.Threading.Tasks;
using ILVerify;
using Pidgin;
using Robust.Shared.Interfaces.Log;
using Robust.Shared.Interfaces.Resources;
using Robust.Shared.Log;
using Robust.Shared.Utility;
// psst
@@ -35,19 +34,23 @@ namespace Robust.Shared.ContentPack
/// <summary>
/// Completely disables type checking, allowing everything.
/// </summary>
public bool DisableTypeCheck { get; set; } = false;
public bool DisableTypeCheck { get; init; }
public DumpFlags Dump { get; set; } = DumpFlags.None;
public bool VerifyIL { get; set; } = true;
public DumpFlags Dump { get; init; } = DumpFlags.None;
public bool VerifyIL { get; init; } = true;
private bool WouldNoOp => Dump == DumpFlags.None && DisableTypeCheck && !VerifyIL;
// Necessary for loads with launcher loader.
public Func<string, Stream?>? ExtraRobustLoader { get; set; }
public Func<string, Stream?>? ExtraRobustLoader { get; init; }
private readonly ISawmill _sawmill;
private readonly SandboxConfig _config;
public AssemblyTypeChecker(IResourceManager res)
public AssemblyTypeChecker(IResourceManager res, ISawmill sawmill)
{
_res = res;
_sawmill = sawmill;
_config = LoadConfig();
}
private Resolver CreateResolver()
@@ -57,16 +60,16 @@ namespace Robust.Shared.ContentPack
string[] loadDirs;
if (string.IsNullOrEmpty(ourPath))
{
Logger.DebugS("res.typecheck", "Robust directory not available");
_sawmill.Debug("Robust directory not available");
loadDirs = new[] {dotnetDir};
}
else
{
Logger.DebugS("res.typecheck", "Robust directory is {0}", ourPath);
_sawmill.Debug("Robust directory is {0}", ourPath);
loadDirs = new[] {dotnetDir, Path.GetDirectoryName(ourPath)!};
}
Logger.DebugS("res.typecheck", ".NET runtime directory is {0}", dotnetDir);
_sawmill.Debug(".NET runtime directory is {0}", dotnetDir);
return new Resolver(
this,
@@ -89,17 +92,18 @@ namespace Robust.Shared.ContentPack
return true;
}
Logger.DebugS("res.typecheck", "Checking assembly...");
_sawmill.Debug("Checking assembly...");
var fullStopwatch = Stopwatch.StartNew();
var config = LoadConfig();
var resolver = CreateResolver();
using var peReader = new PEReader(assembly, PEStreamOptions.LeaveOpen);
var reader = peReader.GetMetadataReader();
var asmName = reader.GetString(reader.GetAssemblyDefinition().Name);
if (VerifyIL)
{
if (!DoVerifyIL(resolver, config, peReader, reader))
if (!DoVerifyIL(asmName, resolver, peReader, reader))
{
return false;
}
@@ -110,13 +114,13 @@ namespace Robust.Shared.ContentPack
var types = GetReferencedTypes(reader, errors);
var members = GetReferencedMembers(reader, errors);
var inherited = GetExternalInheritedTypes(reader, errors);
Logger.DebugS("res.typecheck", $"References loaded... {fullStopwatch.ElapsedMilliseconds}ms");
_sawmill.Debug($"References loaded... {fullStopwatch.ElapsedMilliseconds}ms");
if ((Dump & DumpFlags.Types) != 0)
{
foreach (var mType in types)
{
Logger.DebugS("res.typecheck", $"RefType: {mType}");
_sawmill.Debug($"RefType: {mType}");
}
}
@@ -124,7 +128,7 @@ namespace Robust.Shared.ContentPack
{
foreach (var memberRef in members)
{
Logger.DebugS("res.typecheck", $"RefMember: {memberRef}");
_sawmill.Debug($"RefMember: {memberRef}");
}
}
@@ -132,10 +136,10 @@ namespace Robust.Shared.ContentPack
{
foreach (var (name, baseType, interfaces) in inherited)
{
Logger.DebugS("res.typecheck", $"Inherit: {name} -> {baseType}");
_sawmill.Debug($"Inherit: {name} -> {baseType}");
foreach (var @interface in interfaces)
{
Logger.DebugS("res.typecheck", $" Interface: {@interface}");
_sawmill.Debug($" Interface: {@interface}");
}
}
}
@@ -150,46 +154,49 @@ namespace Robust.Shared.ContentPack
// we won't have to check that any types in their type arguments are whitelisted.
foreach (var type in types)
{
if (!IsTypeAccessAllowed(type, config, out _))
if (!IsTypeAccessAllowed(type, out _))
{
errors.Add(new SandboxError($"Access to type not allowed: {type}"));
}
}
Logger.DebugS("res.typecheck", $"Types... {fullStopwatch.ElapsedMilliseconds}ms");
_sawmill.Debug($"Types... {fullStopwatch.ElapsedMilliseconds}ms");
CheckInheritance(inherited, errors, config);
CheckInheritance(inherited, errors);
Logger.DebugS("res.typecheck", $"Inheritance... {fullStopwatch.ElapsedMilliseconds}ms");
_sawmill.Debug($"Inheritance... {fullStopwatch.ElapsedMilliseconds}ms");
CheckMemberReferences(members, config, errors);
CheckMemberReferences(members, errors);
foreach (var error in errors)
{
Logger.ErrorS("res.typecheck", $"Sandbox violation: {error.Message}");
_sawmill.Error($"Sandbox violation: {error.Message}");
}
Logger.DebugS("res.typecheck", $"Checked assembly in {fullStopwatch.ElapsedMilliseconds}ms");
_sawmill.Debug($"Checked assembly in {fullStopwatch.ElapsedMilliseconds}ms");
return errors.IsEmpty;
}
private static bool DoVerifyIL(Resolver resolver, SandboxConfig config, PEReader peReader,
private bool DoVerifyIL(
string name,
IResolver resolver,
PEReader peReader,
MetadataReader reader)
{
Logger.DebugS("res.typecheck", "Verifying IL...");
_sawmill.Debug($"{name}: Verifying IL...");
var sw = Stopwatch.StartNew();
var ver = new Verifier(resolver);
ver.SetSystemModuleName(new AssemblyName(config.SystemAssemblyName));
ver.SetSystemModuleName(new AssemblyName(_config.SystemAssemblyName));
var verifyErrors = false;
foreach (var res in ver.Verify(peReader))
{
if (config.AllowedVerifierErrors.Contains(res.Code))
if (_config.AllowedVerifierErrors.Contains(res.Code))
{
continue;
}
var msg = $"ILVerify: {res.Message}";
var msg = $"{name}: ILVerify: {res.Message}";
try
{
@@ -213,14 +220,14 @@ namespace Robust.Shared.ContentPack
}
catch (UnsupportedMetadataException e)
{
Logger.ErrorS("res.typecheck", $"{e}");
_sawmill.Error($"{e}");
}
verifyErrors = true;
Logger.ErrorS("res.typecheck", msg);
_sawmill.Error(msg);
}
Logger.DebugS("res.typecheck", $"Verified IL in {sw.Elapsed.TotalMilliseconds}ms");
_sawmill.Debug($"{name}: Verified IL in {sw.Elapsed.TotalMilliseconds}ms");
if (verifyErrors)
{
@@ -230,7 +237,8 @@ namespace Robust.Shared.ContentPack
return true;
}
private static void CheckMemberReferences(List<MMemberRef> members, SandboxConfig config,
private void CheckMemberReferences(
List<MMemberRef> members,
ConcurrentBag<SandboxError> errors)
{
Parallel.ForEach(members, memberRef =>
@@ -261,7 +269,7 @@ namespace Robust.Shared.ContentPack
var baseTypeReferenced = (MTypeReferenced) baseType;
if (!IsTypeAccessAllowed(baseTypeReferenced, config, out var typeCfg))
if (!IsTypeAccessAllowed(baseTypeReferenced, out var typeCfg))
{
// Technically this error isn't necessary since we have an earlier pass
// checking all referenced types. That should have caught this
@@ -325,9 +333,9 @@ namespace Robust.Shared.ContentPack
});
}
private static void CheckInheritance(
private void CheckInheritance(
List<(MType type, MType parent, ArraySegment<MType> interfaceImpls)> inherited,
ConcurrentBag<SandboxError> errors, SandboxConfig config)
ConcurrentBag<SandboxError> errors)
{
// This inheritance whitelisting primarily serves to avoid content doing funny stuff
// by e.g. inheriting Type.
@@ -355,7 +363,7 @@ namespace Robust.Shared.ContentPack
_ => throw new InvalidOperationException() // Can't happen.
};
if (!IsTypeAccessAllowed(realBaseType, config, out var cfg))
if (!IsTypeAccessAllowed(realBaseType, out var cfg))
{
return false;
}
@@ -365,14 +373,13 @@ namespace Robust.Shared.ContentPack
}
}
private static bool IsTypeAccessAllowed(MTypeReferenced type, SandboxConfig config,
[NotNullWhen(true)] out TypeConfig? cfg)
private bool IsTypeAccessAllowed(MTypeReferenced type, [NotNullWhen(true)] out TypeConfig? cfg)
{
if (type.Namespace == null)
{
if (type.ResolutionScope is MResScopeType parentType)
{
if (!IsTypeAccessAllowed((MTypeReferenced) parentType.Type, config, out var parentCfg))
if (!IsTypeAccessAllowed((MTypeReferenced) parentType.Type, out var parentCfg))
{
cfg = null;
return false;
@@ -402,7 +409,7 @@ namespace Robust.Shared.ContentPack
}
// Check if in whitelisted namespaces.
foreach (var whNamespace in config.WhitelistedNamespaces)
foreach (var whNamespace in _config.WhitelistedNamespaces)
{
if (type.Namespace.StartsWith(whNamespace))
{
@@ -411,7 +418,7 @@ namespace Robust.Shared.ContentPack
}
}
if (!config.Types.TryGetValue(type.Namespace, out var nsDict))
if (!_config.Types.TryGetValue(type.Namespace, out var nsDict))
{
cfg = null;
return false;