Make auto comp states infer when data should be cloned (#4461)

This commit is contained in:
DrSmugleaf
2023-09-29 22:14:10 -07:00
committed by GitHub
parent 3b6adeb5ff
commit 4818c3aab4
4 changed files with 58 additions and 45 deletions

View File

@@ -6,6 +6,7 @@ using System.Text;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.Text;
using static Microsoft.CodeAnalysis.SymbolDisplayFormat;
namespace Robust.Shared.CompNetworkGenerator
{
@@ -14,15 +15,23 @@ namespace Robust.Shared.CompNetworkGenerator
{
private const string ClassAttributeName = "Robust.Shared.Analyzers.AutoGenerateComponentStateAttribute";
private const string MemberAttributeName = "Robust.Shared.Analyzers.AutoNetworkedFieldAttribute";
private const string GlobalEntityUidName = "global::Robust.Shared.GameObjects.EntityUid";
private const string GlobalNullableEntityUidName = "global::Robust.Shared.GameObjects.EntityUid?";
private const string GlobalEntityCoordinatesName = "global::Robust.Shared.Map.EntityCoordinates";
private const string GlobalNullableEntityCoordinatesName = "global::Robust.Shared.Map.EntityCoordinates?";
private const string GlobalEntityUidSetName = "global::System.Collections.Generic.HashSet<global::Robust.Shared.GameObjects.EntityUid>";
private const string GlobalNetEntityUidSetName = "global::System.Collections.Generic.HashSet<global::Robust.Shared.GameObjects.NetEntity>";
private const string GlobalEntityUidListName = "global::System.Collections.Generic.List<global::Robust.Shared.GameObjects.EntityUid>";
private const string GlobalNetEntityUidListName = "global::System.Collections.Generic.List<global::Robust.Shared.GameObjects.NetEntity>";
private const string GlobalDictionaryName = "global::System.Collections.Generic.Dictionary<TKey, TValue>";
private const string GlobalHashSetName = "global::System.Collections.Generic.HashSet<T>";
private const string GlobalListName = "global::System.Collections.Generic.List<T>";
private static string GenerateSource(in GeneratorExecutionContext context, INamedTypeSymbol classSymbol, CSharpCompilation comp, bool raiseAfterAutoHandle)
{
var nameSpace = classSymbol.ContainingNamespace.ToDisplayString();
@@ -30,7 +39,7 @@ namespace Robust.Shared.CompNetworkGenerator
var stateName = $"{componentName}_AutoState";
var members = classSymbol.GetMembers();
var fields = new List<(ITypeSymbol Type, string FieldName, AttributeData Attribute)>();
var fields = new List<(ITypeSymbol Type, string FieldName)>();
var fieldAttr = comp.GetTypeByMetadataName(MemberAttributeName);
foreach (var mem in members)
@@ -47,7 +56,7 @@ namespace Robust.Shared.CompNetworkGenerator
switch (mem)
{
case IFieldSymbol field:
fields.Add((field.Type, field.Name, attribute));
fields.Add((field.Type, field.Name));
break;
case IPropertySymbol prop:
{
@@ -83,7 +92,7 @@ namespace Robust.Shared.CompNetworkGenerator
continue;
}
fields.Add((prop.Type, prop.Name, attribute));
fields.Add((prop.Type, prop.Name));
break;
}
}
@@ -121,9 +130,9 @@ namespace Robust.Shared.CompNetworkGenerator
// component.Count = state.Count;
var handleStateSetters = new StringBuilder();
foreach (var (type, name, attribute) in fields)
foreach (var (type, name) in fields)
{
var typeDisplayStr = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
var typeDisplayStr = type.ToDisplayString(FullyQualifiedFormat);
var nullable = type.NullableAnnotation == NullableAnnotation.Annotated;
var nullableAnnotation = nullable ? "?" : string.Empty;
@@ -150,31 +159,6 @@ namespace Robust.Shared.CompNetworkGenerator
handleStateSetters.Append($@"
component.{name} = EnsureCoordinates<{componentName}>(state.{name}, uid);");
break;
default:
stateFields.Append($@"
public {typeDisplayStr} {name} = default!;");
if (attribute.ConstructorArguments[0].Value is bool val && val)
{
// get first ctor arg of the field attribute, which determines whether the field should be cloned
// (like if its a dict or list)
getStateInit.Append($@"
{name} = component.{name},");
handleStateSetters.Append($@"
if (state.{name} != null)
component.{name} = new(state.{name});");
}
else
{
getStateInit.Append($@"
{name} = component.{name},");
handleStateSetters.Append($@"
component.{name} = state.{name};");
}
break;
case GlobalEntityUidSetName:
stateFields.Append($@"
@@ -195,6 +179,33 @@ namespace Robust.Shared.CompNetworkGenerator
handleStateSetters.Append($@"
component.{name} = EnsureEntityList<{componentName}>(state.{name}, uid);");
break;
default:
stateFields.Append($@"
public {typeDisplayStr} {name} = default!;");
if (IsCloneType(type))
{
// get first ctor arg of the field attribute, which determines whether the field should be cloned
// (like if its a dict or list)
getStateInit.Append($@"
{name} = component.{name},");
handleStateSetters.Append($@"
if (state.{name} == null)
component.{name} = null;
else
component.{name} = new(state.{name});");
}
else
{
getStateInit.Append($@"
{name} = component.{name},");
handleStateSetters.Append($@"
component.{name} = state.{name};");
}
break;
}
}
@@ -353,5 +364,20 @@ public partial class {componentName}
}
context.RegisterForSyntaxNotifications(() => new NameReferenceSyntaxReceiver());
}
private static bool IsCloneType(ITypeSymbol type)
{
if (type is not INamedTypeSymbol named || !named.IsGenericType)
{
return false;
}
var constructed = named.ConstructedFrom.ToDisplayString(FullyQualifiedFormat);
return constructed switch
{
GlobalDictionaryName or GlobalHashSetName or GlobalListName => true,
_ => false
};
}
}
}

View File

@@ -2,6 +2,7 @@
<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<LangVersion>9</LangVersion>
</PropertyGroup>
<ItemGroup>