diff --git a/Content.Server.Database/Model.cs b/Content.Server.Database/Model.cs index f107389eb1..b27a07c7b8 100644 --- a/Content.Server.Database/Model.cs +++ b/Content.Server.Database/Model.cs @@ -80,7 +80,6 @@ namespace Content.Server.Database } } - [Table("preference")] public class Preference { // NOTE: on postgres there SHOULD be an FK ensuring that the selected character slot always exists. @@ -88,49 +87,46 @@ namespace Content.Server.Database // Because if I let EFCore know about it it would explode on a circular reference. // Also it has to be DEFERRABLE INITIALLY DEFERRED so that insertion of new preferences works. // Also I couldn't figure out how to create it on SQLite. - - [Column("preference_id")] public int Id { get; set; } - [Column("user_id")] public Guid UserId { get; set; } - [Column("selected_character_slot")] public int SelectedCharacterSlot { get; set; } - [Column("admin_ooc_color")] public string AdminOOCColor { get; set; } = null!; + public int Id { get; set; } + public Guid UserId { get; set; } + public int SelectedCharacterSlot { get; set; } + public string AdminOOCColor { get; set; } = null!; public List Profiles { get; } = new(); } - [Table("profile")] public class Profile { - [Column("profile_id")] public int Id { get; set; } - [Column("slot")] public int Slot { get; set; } + public int Id { get; set; } + public int Slot { get; set; } [Column("char_name")] public string CharacterName { get; set; } = null!; - [Column("age")] public int Age { get; set; } - [Column("sex")] public string Sex { get; set; } = null!; - [Column("gender")] public string Gender { get; set; } = null!; - [Column("hair_name")] public string HairName { get; set; } = null!; - [Column("hair_color")] public string HairColor { get; set; } = null!; - [Column("facial_hair_name")] public string FacialHairName { get; set; } = null!; - [Column("facial_hair_color")] public string FacialHairColor { get; set; } = null!; - [Column("eye_color")] public string EyeColor { get; set; } = null!; - [Column("skin_color")] public string SkinColor { get; set; } = null!; - [Column("clothing")] public string Clothing { get; set; } = null!; - [Column("backpack")] public string Backpack { get; set; } = null!; + public int Age { get; set; } + public string Sex { get; set; } = null!; + public string Gender { get; set; } = null!; + public string HairName { get; set; } = null!; + public string HairColor { get; set; } = null!; + public string FacialHairName { get; set; } = null!; + public string FacialHairColor { get; set; } = null!; + public string EyeColor { get; set; } = null!; + public string SkinColor { get; set; } = null!; + public string Clothing { get; set; } = null!; + public string Backpack { get; set; } = null!; public List Jobs { get; } = new(); public List Antags { get; } = new(); [Column("pref_unavailable")] public DbPreferenceUnavailableMode PreferenceUnavailable { get; set; } - [Column("preference_id")] public int PreferenceId { get; set; } + public int PreferenceId { get; set; } public Preference Preference { get; set; } = null!; } - [Table("job")] public class Job { - [Column("job_id")] public int Id { get; set; } + public int Id { get; set; } public Profile Profile { get; set; } = null!; - [Column("profile_id")] public int ProfileId { get; set; } + public int ProfileId { get; set; } - [Column("job_name")] public string JobName { get; set; } = null!; - [Column("priority")] public DbJobPriority Priority { get; set; } + public string JobName { get; set; } = null!; + public DbJobPriority Priority { get; set; } } public enum DbJobPriority @@ -142,14 +138,13 @@ namespace Content.Server.Database High = 3 } - [Table("antag")] public class Antag { - [Column("antag_id")] public int Id { get; set; } + public int Id { get; set; } public Profile Profile { get; set; } = null!; - [Column("profile_id")] public int ProfileId { get; set; } + public int ProfileId { get; set; } - [Column("antag_name")] public string AntagName { get; set; } = null!; + public string AntagName { get; set; } = null!; } public enum DbPreferenceUnavailableMode @@ -159,54 +154,49 @@ namespace Content.Server.Database SpawnAsOverflow, } - [Table("assigned_user_id")] public class AssignedUserId { - [Column("assigned_user_id_id")] public int Id { get; set; } - [Column("user_name")] public string UserName { get; set; } = null!; + public int Id { get; set; } + public string UserName { get; set; } = null!; - [Column("user_id")] public Guid UserId { get; set; } + public Guid UserId { get; set; } } - [Table("admin")] public class Admin { - [Column("user_id"), Key] public Guid UserId { get; set; } - [Column("title")] public string? Title { get; set; } + [Key] public Guid UserId { get; set; } + public string? Title { get; set; } - [Column("admin_rank_id")] public int? AdminRankId { get; set; } + public int? AdminRankId { get; set; } public AdminRank? AdminRank { get; set; } public List Flags { get; set; } = default!; } - [Table("admin_flag")] public class AdminFlag { - [Column("admin_flag_id")] public int Id { get; set; } - [Column("flag")] public string Flag { get; set; } = default!; - [Column("negative")] public bool Negative { get; set; } + public int Id { get; set; } + public string Flag { get; set; } = default!; + public bool Negative { get; set; } - [Column("admin_id")] public Guid AdminId { get; set; } + public Guid AdminId { get; set; } public Admin Admin { get; set; } = default!; } - [Table("admin_rank")] public class AdminRank { - [Column("admin_rank_id")] public int Id { get; set; } - [Column("name")] public string Name { get; set; } = default!; + public int Id { get; set; } + public string Name { get; set; } = default!; public List Admins { get; set; } = default!; public List Flags { get; set; } = default!; } - [Table("admin_rank_flag")] public class AdminRankFlag { - [Column("admin_rank_flag_id")] public int Id { get; set; } - [Column("flag")] public string Flag { get; set; } = default!; + public int Id { get; set; } + public string Flag { get; set; } = default!; - [Column("admin_rank_id")] public int AdminRankId { get; set; } + public int AdminRankId { get; set; } public AdminRank Rank { get; set; } = default!; } } diff --git a/Content.Server.Database/ModelPostgres.cs b/Content.Server.Database/ModelPostgres.cs index f36536fa49..2ee28db9a5 100644 --- a/Content.Server.Database/ModelPostgres.cs +++ b/Content.Server.Database/ModelPostgres.cs @@ -2,6 +2,7 @@ using System.ComponentModel.DataAnnotations.Schema; using System.Net; using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Storage; namespace Content.Server.Database @@ -25,6 +26,8 @@ namespace Content.Server.Database options.UseNpgsql("dummy connection string"); options.ReplaceService(); + + ((IDbContextOptionsBuilderInfrastructure) options).AddOrUpdateExtension(new SnakeCaseExtension()); } public PostgresServerDbContext(DbContextOptions options) : base(options) @@ -74,26 +77,32 @@ namespace Content.Server.Database modelBuilder.Entity() .HasCheckConstraint("AddressNotIPv6MappedIPv4", "NOT inet '::ffff:0.0.0.0/96' >>= address"); + + foreach(var entity in modelBuilder.Model.GetEntityTypes()) + { + foreach(var property in entity.GetProperties()) + { + if (property.FieldInfo.FieldType == typeof(DateTime) || property.FieldInfo.FieldType == typeof(DateTime?)) + property.SetColumnType("timestamp with time zone"); + } + } } } [Table("server_ban")] public class PostgresServerBan { - [Column("server_ban_id")] public int Id { get; set; } + public int Id { get; set; } + public Guid? UserId { get; set; } + [Column(TypeName = "inet")] public (IPAddress, int)? Address { get; set; } + public byte[]? HWId { get; set; } - [Column("user_id")] public Guid? UserId { get; set; } - [Column("address", TypeName = "inet")] public (IPAddress, int)? Address { get; set; } - [Column("hwid")] public byte[]? HWId { get; set; } - - [Column("ban_time", TypeName = "timestamp with time zone")] public DateTime BanTime { get; set; } - [Column("expiration_time", TypeName = "timestamp with time zone")] public DateTime? ExpirationTime { get; set; } - [Column("reason")] public string Reason { get; set; } = null!; - [Column("banning_admin")] public Guid? BanningAdmin { get; set; } + public string Reason { get; set; } = null!; + public Guid? BanningAdmin { get; set; } public PostgresServerUnban? Unban { get; set; } } @@ -103,48 +112,44 @@ namespace Content.Server.Database { [Column("unban_id")] public int Id { get; set; } - [Column("ban_id")] public int BanId { get; set; } - [Column("ban")] public PostgresServerBan Ban { get; set; } = null!; + public int BanId { get; set; } + public PostgresServerBan Ban { get; set; } = null!; - [Column("unbanning_admin")] public Guid? UnbanningAdmin { get; set; } + public Guid? UnbanningAdmin { get; set; } - [Column("unban_time", TypeName = "timestamp with time zone")] public DateTime UnbanTime { get; set; } } [Table("player")] public class PostgresPlayer { - [Column("player_id")] public int Id { get; set; } + public int Id { get; set; } // Permanent data - [Column("user_id")] public Guid UserId { get; set; } + public Guid UserId { get; set; } - [Column("first_seen_time", TypeName = "timestamp with time zone")] public DateTime FirstSeenTime { get; set; } // Data that gets updated on each join. - [Column("last_seen_user_name")] public string LastSeenUserName { get; set; } = null!; + public string LastSeenUserName { get; set; } = null!; - [Column("last_seen_time", TypeName = "timestamp with time zone")] public DateTime LastSeenTime { get; set; } - [Column("last_seen_address")] public IPAddress LastSeenAddress { get; set; } = null!; - [Column("last_seen_hwid")] public byte[]? LastSeenHWId { get; set; } + public IPAddress LastSeenAddress { get; set; } = null!; + public byte[]? LastSeenHWId { get; set; } } [Table("connection_log")] public class PostgresConnectionLog { - [Column("connection_log_id")] public int Id { get; set; } + public int Id { get; set; } - [Column("user_id")] public Guid UserId { get; set; } - [Column("user_name")] public string UserName { get; set; } = null!; + public Guid UserId { get; set; } + public string UserName { get; set; } = null!; - [Column("time", TypeName = "timestamp with time zone")] public DateTime Time { get; set; } - [Column("address")] public IPAddress Address { get; set; } = null!; - [Column("hwid")] public byte[]? HWId { get; set; } + public IPAddress Address { get; set; } = null!; + public byte[]? HWId { get; set; } } } diff --git a/Content.Server.Database/ModelSqlite.cs b/Content.Server.Database/ModelSqlite.cs index e82c16cf72..0eaaf3072c 100644 --- a/Content.Server.Database/ModelSqlite.cs +++ b/Content.Server.Database/ModelSqlite.cs @@ -1,6 +1,10 @@ using System; using System.ComponentModel.DataAnnotations.Schema; +using System.Globalization; +using System.Net; using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Storage.ValueConversion; namespace Content.Server.Database { @@ -19,6 +23,8 @@ namespace Content.Server.Database { if (!InitializedWithOptions) options.UseSqlite("dummy connection string"); + + ((IDbContextOptionsBuilderInfrastructure) options).AddOrUpdateExtension(new SnakeCaseExtension()); } protected override void OnModelCreating(ModelBuilder modelBuilder) @@ -27,26 +33,56 @@ namespace Content.Server.Database modelBuilder.Entity() .HasIndex(p => p.LastSeenUserName); + + var converter = new ValueConverter<(IPAddress address, int mask), string>( + v => InetToString(v.address, v.mask), + v => StringToInet(v) + ); + + modelBuilder + .Entity() + .Property(e => e.Address) + .HasColumnType("TEXT") + .HasConversion(converter); } public SqliteServerDbContext(DbContextOptions options) : base(options) { } + + private static string InetToString(IPAddress address, int mask) { + if (address.IsIPv4MappedToIPv6) + { + // Fix IPv6-mapped IPv4 addresses + // So that IPv4 addresses are consistent between separate-socket and dual-stack socket modes. + address = address.MapToIPv4(); + mask -= 96; + } + return $"{address}/{mask}"; + } + + private static (IPAddress, int) StringToInet(string inet) { + var idx = inet.IndexOf('/', StringComparison.Ordinal); + return ( + IPAddress.Parse(inet.AsSpan(0, idx)), + int.Parse(inet.AsSpan(idx + 1), provider: CultureInfo.InvariantCulture) + ); + } } [Table("ban")] public class SqliteServerBan { - [Column("ban_id")] public int Id { get; set; } + public int Id { get; set; } - [Column("user_id")] public Guid? UserId { get; set; } - [Column("address")] public string? Address { get; set; } - [Column("hwid")] public byte[]? HWId { get; set; } + public Guid? UserId { get; set; } + public (IPAddress address, int mask)? Address { get; set; } + public byte[]? HWId { get; set; } - [Column("ban_time")] public DateTime BanTime { get; set; } - [Column("expiration_time")] public DateTime? ExpirationTime { get; set; } - [Column("reason")] public string Reason { get; set; } = null!; - [Column("banning_admin")] public Guid? BanningAdmin { get; set; } + public DateTime BanTime { get; set; } + public DateTime? ExpirationTime { get; set; } + public string Reason { get; set; } = null!; + public Guid? BanningAdmin { get; set; } public SqliteServerUnban? Unban { get; set; } } @@ -56,38 +92,38 @@ namespace Content.Server.Database { [Column("unban_id")] public int Id { get; set; } - [Column("ban_id")] public int BanId { get; set; } + public int BanId { get; set; } public SqliteServerBan Ban { get; set; } = null!; - [Column("unbanning_admin")] public Guid? UnbanningAdmin { get; set; } - [Column("unban_time")] public DateTime UnbanTime { get; set; } + public Guid? UnbanningAdmin { get; set; } + public DateTime UnbanTime { get; set; } } [Table("player")] public class SqlitePlayer { - [Column("player_id")] public int Id { get; set; } + public int Id { get; set; } // Permanent data - [Column("user_id")] public Guid UserId { get; set; } - [Column("first_seen_time")] public DateTime FirstSeenTime { get; set; } + public Guid UserId { get; set; } + public DateTime FirstSeenTime { get; set; } // Data that gets updated on each join. - [Column("last_seen_user_name")] public string LastSeenUserName { get; set; } = null!; - [Column("last_seen_time")] public DateTime LastSeenTime { get; set; } - [Column("last_seen_address")] public string LastSeenAddress { get; set; } = null!; - [Column("last_seen_hwid")] public byte[]? LastSeenHWId { get; set; } + public string LastSeenUserName { get; set; } = null!; + public DateTime LastSeenTime { get; set; } + public string LastSeenAddress { get; set; } = null!; + public byte[]? LastSeenHWId { get; set; } } [Table("connection_log")] public class SqliteConnectionLog { - [Column("connection_log_id")] public int Id { get; set; } + public int Id { get; set; } - [Column("user_id")] public Guid UserId { get; set; } - [Column("user_name")] public string UserName { get; set; } = null!; - [Column("time")] public DateTime Time { get; set; } - [Column("address")] public string Address { get; set; } = null!; - [Column("hwid")] public byte[]? HWId { get; set; } + public Guid UserId { get; set; } + public string UserName { get; set; } = null!; + public DateTime Time { get; set; } + public string Address { get; set; } = null!; + public byte[]? HWId { get; set; } } } diff --git a/Content.Server.Database/SnakeCaseNaming.cs b/Content.Server.Database/SnakeCaseNaming.cs new file mode 100644 index 0000000000..e9b204d5ac --- /dev/null +++ b/Content.Server.Database/SnakeCaseNaming.cs @@ -0,0 +1,322 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text.RegularExpressions; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Metadata.Builders; +using Microsoft.EntityFrameworkCore.Metadata.Conventions; +using Microsoft.EntityFrameworkCore.Metadata.Conventions.Infrastructure; +using Microsoft.Extensions.DependencyInjection; + +namespace Content.Server.Database +{ + public class SnakeCaseExtension : IDbContextOptionsExtension + { + public DbContextOptionsExtensionInfo Info { get; } + + public SnakeCaseExtension() { + Info = new ExtensionInfo(this); + } + + public void ApplyServices(IServiceCollection services) + => services.AddSnakeCase(); + + public void Validate(IDbContextOptions options) {} + + private sealed class ExtensionInfo : DbContextOptionsExtensionInfo + { + public ExtensionInfo(IDbContextOptionsExtension extension) : base(extension) {} + + public override bool IsDatabaseProvider => false; + + public override string LogFragment => "Snake Case Extension"; + + public override long GetServiceProviderHashCode() => 0; + + public override void PopulateDebugInfo(IDictionary debugInfo) + { + } + } + } + + public static class SnakeCaseServiceCollectionExtensions + { + public static IServiceCollection AddSnakeCase( + this IServiceCollection serviceCollection) + { + new EntityFrameworkServicesBuilder(serviceCollection) + .TryAdd(); + + return serviceCollection; + } + } + + public class SnakeCaseConventionSetPlugin : IConventionSetPlugin + { + public ConventionSet ModifyConventions(ConventionSet conventionSet) + { + var convention = new SnakeCaseConvention(); + + conventionSet.EntityTypeAddedConventions.Add(convention); + conventionSet.EntityTypeAnnotationChangedConventions.Add(convention); + conventionSet.PropertyAddedConventions.Add(convention); + conventionSet.ForeignKeyOwnershipChangedConventions.Add(convention); + conventionSet.KeyAddedConventions.Add(convention); + conventionSet.ForeignKeyAddedConventions.Add(convention); + conventionSet.EntityTypeBaseTypeChangedConventions.Add(convention); + conventionSet.ModelFinalizingConventions.Add(convention); + + return conventionSet; + } + } + + public class SnakeCaseConvention : + IEntityTypeAddedConvention, + IEntityTypeAnnotationChangedConvention, + IPropertyAddedConvention, + IForeignKeyOwnershipChangedConvention, + IKeyAddedConvention, + IForeignKeyAddedConvention, + IEntityTypeBaseTypeChangedConvention, + IModelFinalizingConvention + { + private static readonly StoreObjectType[] _storeObjectTypes + = { StoreObjectType.Table, StoreObjectType.View, StoreObjectType.Function, StoreObjectType.SqlQuery }; + + public SnakeCaseConvention() {} + + public static string RewriteName(string name) + { + var regex = new Regex("[A-Z]+", RegexOptions.Compiled); + return regex.Replace( + name, + (Match match) => { + if (match.Index == 0 && (match.Value == "FK" || match.Value == "PK" || match.Value == "IX")) { + return match.Value; + } + if (match.Value == "HWI") + return (match.Index == 0 ? "" : "_") + "hwi"; + if (match.Index == 0) + return match.Value.ToLower(); + if (match.Length > 1) + return $"_{match.Value[..^1].ToLower()}_{match.Value[^1..^0].ToLower()}"; + return "_" + match.Value.ToLower(); + } + ); + } + + public virtual void ProcessEntityTypeAdded( + IConventionEntityTypeBuilder entityTypeBuilder, + IConventionContext context) + { + var entityType = entityTypeBuilder.Metadata; + + if (entityType.ClrType == typeof(Microsoft.EntityFrameworkCore.Migrations.HistoryRow)) + return; + + if (entityType.BaseType is null) + { + entityTypeBuilder.ToTable(RewriteName(entityType.GetTableName()), entityType.GetSchema()); + + if (entityType.GetViewNameConfigurationSource() == ConfigurationSource.Convention) + { + entityTypeBuilder.ToView(RewriteName(entityType.GetViewName()), entityType.GetViewSchema()); + } + } + } + + public void ProcessEntityTypeBaseTypeChanged( + IConventionEntityTypeBuilder entityTypeBuilder, + IConventionEntityType newBaseType, + IConventionEntityType oldBaseType, + IConventionContext context) + { + var entityType = entityTypeBuilder.Metadata; + + if (newBaseType is null) + { + entityTypeBuilder.ToTable(RewriteName(entityType.GetTableName()), entityType.GetSchema()); + } + else + { + entityTypeBuilder.HasNoAnnotation(RelationalAnnotationNames.TableName); + entityTypeBuilder.HasNoAnnotation(RelationalAnnotationNames.Schema); + } + } + + public virtual void ProcessPropertyAdded( + IConventionPropertyBuilder propertyBuilder, + IConventionContext context) + => RewriteColumnName(propertyBuilder); + + public void ProcessForeignKeyOwnershipChanged(IConventionForeignKeyBuilder relationshipBuilder, IConventionContext context) + { + var foreignKey = relationshipBuilder.Metadata; + var ownedEntityType = foreignKey.DeclaringEntityType; + + if (foreignKey.IsOwnership && ownedEntityType.GetTableNameConfigurationSource() != ConfigurationSource.Explicit) + { + ownedEntityType.Builder.HasNoAnnotation(RelationalAnnotationNames.TableName); + ownedEntityType.Builder.HasNoAnnotation(RelationalAnnotationNames.Schema); + + ownedEntityType.FindPrimaryKey()?.Builder.HasNoAnnotation(RelationalAnnotationNames.Name); + + foreach (var property in ownedEntityType.GetProperties()) + { + RewriteColumnName(property.Builder); + } + } + } + + public void ProcessEntityTypeAnnotationChanged( + IConventionEntityTypeBuilder entityTypeBuilder, + string name, + IConventionAnnotation annotation, + IConventionAnnotation oldAnnotation, + IConventionContext context) + { + var entityType = entityTypeBuilder.Metadata; + + if (entityType.ClrType == typeof(Microsoft.EntityFrameworkCore.Migrations.HistoryRow)) + return; + + if (name != RelationalAnnotationNames.TableName + || StoreObjectIdentifier.Create(entityType, StoreObjectType.Table) is not StoreObjectIdentifier tableIdentifier) + { + return; + } + + if (entityType.FindPrimaryKey() is IConventionKey primaryKey) + { + if (entityType.FindRowInternalForeignKeys(tableIdentifier).FirstOrDefault() is null + && (entityType.BaseType is null || entityType.GetTableName() == entityType.BaseType.GetTableName())) + { + primaryKey.Builder.HasName(RewriteName(primaryKey.GetDefaultName())); + } + else + { + primaryKey.Builder.HasNoAnnotation(RelationalAnnotationNames.Name); + } + } + + foreach (var foreignKey in entityType.GetForeignKeys()) + { + foreignKey.Builder.HasConstraintName(RewriteName(foreignKey.GetDefaultName())); + } + + foreach (var index in entityType.GetIndexes()) + { + index.Builder.HasDatabaseName(RewriteName(index.GetDefaultDatabaseName())); + } + + if (annotation?.Value is not null + && entityType.FindOwnership() is IConventionForeignKey ownership + && (string)annotation.Value != ownership.PrincipalEntityType.GetTableName()) + { + foreach (var property in entityType.GetProperties() + .Except(entityType.FindPrimaryKey().Properties) + .Where(p => p.Builder.CanSetColumnName(null))) + { + RewriteColumnName(property.Builder); + } + + if (entityType.FindPrimaryKey() is IConventionKey key) + { + key.Builder.HasName(RewriteName(key.GetDefaultName())); + } + } + } + + public void ProcessForeignKeyAdded( + IConventionForeignKeyBuilder relationshipBuilder, + IConventionContext context) + => relationshipBuilder.HasConstraintName(RewriteName(relationshipBuilder.Metadata.GetDefaultName())); + + public void ProcessKeyAdded(IConventionKeyBuilder keyBuilder, IConventionContext context) + { + var entityType = keyBuilder.Metadata.DeclaringEntityType; + + if (entityType.ClrType == typeof(Microsoft.EntityFrameworkCore.Migrations.HistoryRow)) + return; + + if (entityType.FindOwnership() is null) + { + keyBuilder.HasName(RewriteName(keyBuilder.Metadata.GetDefaultName())); + } + } + + public void ProcessModelFinalizing(IConventionModelBuilder modelBuilder, IConventionContext context) + { + foreach (var entityType in modelBuilder.Metadata.GetEntityTypes()) + { + if (entityType.ClrType == typeof(Microsoft.EntityFrameworkCore.Migrations.HistoryRow)) + continue; + + foreach (var property in entityType.GetProperties()) + { + var columnName = property.GetColumnBaseName(); + if (columnName.StartsWith(entityType.ShortName() + '_', StringComparison.Ordinal)) + { + property.Builder.HasColumnName( + RewriteName(entityType.ShortName()) + columnName[entityType.ShortName().Length..]); + } + + foreach (var storeObjectType in _storeObjectTypes) + { + var identifier = StoreObjectIdentifier.Create(entityType, storeObjectType); + if (identifier is null) + continue; + + if (property.GetColumnNameConfigurationSource(identifier.Value) == ConfigurationSource.Convention) + { + columnName = property.GetColumnName(identifier.Value); + if (columnName.StartsWith(entityType.ShortName() + '_', StringComparison.Ordinal)) + { + property.Builder.HasColumnName( + RewriteName(entityType.ShortName()) + + columnName[entityType.ShortName().Length..]); + } + } + } + } + } + } + + private void RewriteColumnName(IConventionPropertyBuilder propertyBuilder) + { + var property = propertyBuilder.Metadata; + var entityType = property.DeclaringEntityType; + + if (entityType.ClrType == typeof(Microsoft.EntityFrameworkCore.Migrations.HistoryRow)) + return; + + property.Builder.HasNoAnnotation(RelationalAnnotationNames.ColumnName); + + var baseColumnName = StoreObjectIdentifier.Create(property.DeclaringEntityType, StoreObjectType.Table) is { } tableIdentifier + ? property.GetDefaultColumnName(tableIdentifier) + : property.GetDefaultColumnBaseName(); + + if (baseColumnName == "Id") + baseColumnName = entityType.GetTableName() + baseColumnName; + propertyBuilder.HasColumnName(RewriteName(baseColumnName)); + + foreach (var storeObjectType in _storeObjectTypes) + { + var identifier = StoreObjectIdentifier.Create(entityType, storeObjectType); + if (identifier is null) + continue; + + if (property.GetColumnNameConfigurationSource(identifier.Value) == ConfigurationSource.Convention) + { + var name = property.GetColumnName(identifier.Value); + if (name == "Id") + name = entityType.GetTableName() + name; + propertyBuilder.HasColumnName( + RewriteName(name), identifier.Value); + } + } + } + } +} diff --git a/Content.Server.Database/add-migration.sh b/Content.Server.Database/add-migration.sh new file mode 100755 index 0000000000..cc252726dd --- /dev/null +++ b/Content.Server.Database/add-migration.sh @@ -0,0 +1,9 @@ +#!/usr/bin/env bash + +if [ -z "$1" ] ; then + echo "Must specify migration name" + exit 1 +fi + +dotnet ef migrations add --context SqliteServerDbContext -o Migrations/Sqlite "$1" +dotnet ef migrations add --context PostgresServerDbContext -o Migrations/Postgres "$1" diff --git a/Content.Server/Database/ServerDbManager.cs b/Content.Server/Database/ServerDbManager.cs index b0d2d032e9..d8f419046f 100644 --- a/Content.Server/Database/ServerDbManager.cs +++ b/Content.Server/Database/ServerDbManager.cs @@ -313,6 +313,9 @@ namespace Content.Server.Database Username = user, Password = pass }.ConnectionString; + + Logger.DebugS("db.manager", $"Using Postgres \"{host}:{port}/{db}\""); + builder.UseNpgsql(connectionString); SetupLogging(builder); return builder.Options; @@ -329,10 +332,12 @@ namespace Content.Server.Database if (!inMemory) { var finalPreferencesDbPath = Path.Combine(_res.UserData.RootDir!, configPreferencesDbPath); + Logger.DebugS("db.manager", $"Using SQLite DB \"{finalPreferencesDbPath}\""); connection = new SqliteConnection($"Data Source={finalPreferencesDbPath}"); } else { + Logger.DebugS("db.manager", $"Using in-memory SQLite DB"); connection = new SqliteConnection("Data Source=:memory:"); // When using an in-memory DB we have to open it manually // so EFCore doesn't open, close and wipe it. diff --git a/Content.Server/Database/ServerDbSqlite.cs b/Content.Server/Database/ServerDbSqlite.cs index bd840db20d..a82a9e0b39 100644 --- a/Content.Server/Database/ServerDbSqlite.cs +++ b/Content.Server/Database/ServerDbSqlite.cs @@ -104,7 +104,7 @@ namespace Content.Server.Database NetUserId? userId, ImmutableArray? hwId) { - if (address != null && ban.Address != null && IPAddressExt.IsInSubnet(address, ban.Address)) + if (address != null && ban.Address is not null && IPAddressExt.IsInSubnet(address, ban.Address.Value)) { return true; } @@ -126,15 +126,9 @@ namespace Content.Server.Database { await using var db = await GetDbImpl(); - string? addrStr = null; - if (serverBan.Address is { } addr) - { - addrStr = $"{addr.address}/{addr.cidrMask}"; - } - db.SqliteDbContext.Ban.Add(new SqliteServerBan { - Address = addrStr, + Address = serverBan.Address, Reason = serverBan.Reason, BanningAdmin = serverBan.BanningAdmin?.UserId, HWId = serverBan.HWId?.ToArray(), @@ -245,20 +239,12 @@ namespace Content.Server.Database aUid = new NetUserId(aGuid); } - (IPAddress, int)? addrTuple = null; - if (ban.Address != null) - { - var idx = ban.Address.IndexOf('/', StringComparison.Ordinal); - addrTuple = (IPAddress.Parse(ban.Address.AsSpan(0, idx)), - int.Parse(ban.Address.AsSpan(idx + 1), provider: CultureInfo.InvariantCulture)); - } - var unban = ConvertUnban(ban.Unban); return new ServerBanDef( ban.Id, uid, - addrTuple, + ban.Address, ban.HWId == null ? null : ImmutableArray.Create(ban.HWId), ban.BanTime, ban.ExpirationTime, diff --git a/Content.Server/IP/IPAddressExt.cs b/Content.Server/IP/IPAddressExt.cs index 51230e20cd..269e4e8589 100644 --- a/Content.Server/IP/IPAddressExt.cs +++ b/Content.Server/IP/IPAddressExt.cs @@ -18,7 +18,7 @@ namespace Content.Server.IP } // First parse the address of the netmask before the prefix length. - var maskAddress = System.Net.IPAddress.Parse(subnetMask.Substring(0, slashIdx)); + var maskAddress = System.Net.IPAddress.Parse(subnetMask[..slashIdx]); if (maskAddress.AddressFamily != address.AddressFamily) { @@ -27,8 +27,18 @@ namespace Content.Server.IP } // Now find out how long the prefix is. - int maskLength = int.Parse(subnetMask.Substring(slashIdx + 1)); + int maskLength = int.Parse(subnetMask[(slashIdx + 1)..]); + return address.IsInSubnet(maskAddress, maskLength); + } + + public static bool IsInSubnet(this System.Net.IPAddress address, (System.Net.IPAddress maskAddress, int maskLength) tuple) + { + return address.IsInSubnet(tuple.maskAddress, tuple.maskLength); + } + + public static bool IsInSubnet(this System.Net.IPAddress address, System.Net.IPAddress maskAddress, int maskLength) + { if (maskAddress.AddressFamily == AddressFamily.InterNetwork) { // Convert the mask address to an unsigned integer.