WESL-based shader compilation

This commit is contained in:
PJB3005
2025-10-09 03:03:23 +02:00
parent 312944eb9a
commit 3d6fda1aca
22 changed files with 456 additions and 74 deletions

View File

@@ -0,0 +1,27 @@
import Robust::SpriteBatch::{VertexInput, VertexOutput, mainTexture, mainSampler, View};
import Robust::Math::srgb_to_linear;
@vertex
fn vs_main(input: VertexInput) -> VertexOutput {
var transformed = vec3(input.position, 1.0) * View.projViewMatrix;
transformed += 1.0;
transformed /= View.screenPixelSize * 2.0;
transformed = floor(transformed + 0.5);
transformed *= View.screenPixelSize * 2.0;
transformed -= 1.0;
var out: VertexOutput;
out.position = vec4(transformed, 0.0, 1.0);
out.texCoord = input.texCoord;
out.color = srgb_to_linear(input.color);
return out;
}
@fragment
fn fs_main(input: VertexOutput) -> @location(0) vec4f {
var color = textureSample(mainTexture, mainSampler, input.texCoord);
color = color * input.color;
return color;
}

View File

@@ -0,0 +1,6 @@
fn srgb_to_linear(srgb: vec4f) -> vec4f {
let higher = pow((srgb.rgb + 0.055) / 1.055, vec3(2.4));
let lower = srgb.rgb / 12.92;
let s = max(vec3(0.0), sign(srgb.rgb - 0.04045));
return vec4(mix(lower, higher, s), srgb.a);
}

View File

@@ -0,0 +1,33 @@
// Group 0: global constants.
struct UniformConstants {
time: f32
}
@group(0) @binding(0) var<uniform> Constants: UniformConstants;
// Group 1: parameters that change infrequently in a draw pass.
struct UniformView {
projViewMatrix: mat2x3f,
screenPixelSize: vec2f
}
@group(1) @binding(0) var<uniform> View: UniformView;
// Group 2: per-draw parameters.
@group(2) @binding(0)
var mainTexture: texture_2d<f32>;
@group(2) @binding(1)
var mainSampler: sampler;
struct VertexInput {
@location(0) position: vec2f,
@location(1) texCoord: vec2f,
@location(2) color: vec4f
}
struct VertexOutput {
@builtin(position) position: vec4f,
@location(0) texCoord: vec2f,
@location(1) color: vec4f,
}

View File

@@ -1,66 +0,0 @@
// Group 0: global constants.
struct UniformConstants {
time: f32
}
@group(0) @binding(0) var<uniform> Constants: UniformConstants;
// Group 1: parameters that change infrequently in a draw pass.
struct UniformView {
projViewMatrix: mat2x3f,
screenPixelSize: vec2f
}
@group(1) @binding(0) var<uniform> View: UniformView;
// Group 2: per-draw parameters.
@group(2) @binding(0)
var mainTexture: texture_2d<f32>;
@group(2) @binding(1)
var mainSampler: sampler;
struct VertexInput {
@location(0) position: vec2f,
@location(1) texCoord: vec2f,
@location(2) color: vec4f
}
struct VertexOutput {
@builtin(position) position: vec4f,
@location(0) texCoord: vec2f,
@location(1) color: vec4f,
}
@vertex
fn vs_main(input: VertexInput) -> VertexOutput {
var transformed = vec3(input.position, 1.0) * View.projViewMatrix;
transformed += 1.0;
transformed /= View.screenPixelSize * 2.0;
transformed = floor(transformed + 0.5);
transformed *= View.screenPixelSize * 2.0;
transformed -= 1.0;
var out: VertexOutput;
out.position = vec4(transformed, 0.0, 1.0);
out.texCoord = input.texCoord;
out.color = srgb_to_linear(input.color);
return out;
}
@fragment
fn fs_main(input: VertexOutput) -> @location(0) vec4f {
var color = textureSample(mainTexture, mainSampler, input.texCoord);
color = color * input.color;
return color;
}
fn srgb_to_linear(srgb: vec4f) -> vec4f {
let higher = pow((srgb.rgb + 0.055) / 1.055, vec3(2.4));
let lower = srgb.rgb / 12.92;
let s = max(vec3(0.0), sign(srgb.rgb - 0.04045));
return vec4(mix(lower, higher, s), srgb.a);
}

View File

@@ -38,6 +38,7 @@ public abstract partial class RhiBase
public abstract RhiSampler CreateSampler(in RhiSamplerDescriptor descriptor);
public abstract RhiShaderModule CreateShaderModule(in RhiShaderModuleDescriptor descriptor);
public abstract RhiShaderModule CreateShaderModule(in RhiShaderModuleDescriptorUtf8 descriptor);
public abstract RhiPipelineLayout CreatePipelineLayout(in RhiPipelineLayoutDescriptor descriptor);
@@ -803,6 +804,13 @@ public record struct RhiShaderModuleDescriptor(
string? Label
);
public record struct RhiShaderModuleDescriptorUtf8(
// TODO: Hints
// TODO: source map ?
byte[] Code,
string? Label
);
public record struct RhiImageCopyTexture(
RhiTexture Texture,
uint MipLevel = 0,

View File

@@ -10,8 +10,17 @@ internal sealed unsafe partial class RhiWebGpu
{
var codeBytes = Encoding.UTF8.GetBytes(descriptor.Code);
return CreateShaderModule(new RhiShaderModuleDescriptorUtf8
{
Code = codeBytes,
Label = descriptor.Label
});
}
public override RhiShaderModule CreateShaderModule(in RhiShaderModuleDescriptorUtf8 descriptor)
{
WGPUShaderModule shaderModule;
fixed (byte* pCode = codeBytes)
fixed (byte* pCode = descriptor.Code)
fixed (byte* pLabel = MakeLabel(descriptor.Label))
{
var descWgsl = new WGPUShaderSourceWGSL();

View File

@@ -13,6 +13,9 @@ internal static unsafe partial class Wesl
[DllImport("robust-native", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
public static extern WeslResult wesl_compile([NativeTypeName("const WeslStringMap *")] WeslStringMap* files, [NativeTypeName("const char *")] sbyte* root, [NativeTypeName("const WeslCompileOptions *")] WeslCompileOptions* options, [NativeTypeName("const WeslStringArray *")] WeslStringArray* keep, [NativeTypeName("const WeslBoolMap *")] WeslBoolMap* features);
[DllImport("robust-native", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
public static extern WeslParseResult wesl_parse([NativeTypeName("const char *")] sbyte* source);
[DllImport("robust-native", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
public static extern WeslResult wesl_eval([NativeTypeName("const WeslStringMap *")] WeslStringMap* files, [NativeTypeName("const char *")] sbyte* root, [NativeTypeName("const char *")] sbyte* expression, [NativeTypeName("const WeslCompileOptions *")] WeslCompileOptions* options, [NativeTypeName("const WeslBoolMap *")] WeslBoolMap* features);
@@ -28,6 +31,9 @@ internal static unsafe partial class Wesl
[DllImport("robust-native", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
public static extern void wesl_free_exec_result(WeslExecResult* result);
[DllImport("robust-native", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
public static extern void wesl_free_translation_unit(WeslTranslationUnit* unit);
[DllImport("robust-native", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
[return: NativeTypeName("const char *")]
public static extern sbyte* wesl_version();

View File

@@ -1,6 +1,6 @@
namespace Robust.Client.Interop.RobustNative.Wesl.Gen;
internal partial struct WeslCompileOptions
internal unsafe partial struct WeslCompileOptions
{
public WeslManglerKind mangler;
@@ -36,4 +36,6 @@ internal partial struct WeslCompileOptions
[NativeTypeName("_Bool")]
public byte mangle_root;
public WeslResolverOptions* resolver;
}

View File

@@ -0,0 +1,11 @@
namespace Robust.Client.Interop.RobustNative.Wesl.Gen;
internal unsafe partial struct WeslParseResult
{
[NativeTypeName("_Bool")]
public byte success;
public WeslTranslationUnit* data;
public WeslError error;
}

View File

@@ -0,0 +1,9 @@
namespace Robust.Client.Interop.RobustNative.Wesl.Gen;
internal unsafe partial struct WeslResolveModuleResult
{
[NativeTypeName("_Bool")]
public byte success;
public WeslTranslationUnit* module;
}

View File

@@ -0,0 +1,10 @@
namespace Robust.Client.Interop.RobustNative.Wesl.Gen;
internal unsafe partial struct WeslResolveSourceResult
{
[NativeTypeName("_Bool")]
public byte success;
[NativeTypeName("const char *")]
public sbyte* source;
}

View File

@@ -0,0 +1,30 @@
namespace Robust.Client.Interop.RobustNative.Wesl.Gen;
internal unsafe partial struct WeslResolverOptions
{
public void* userdata;
[NativeTypeName("WeslResolveSourceFunction")]
public delegate* unmanaged[Cdecl]<sbyte*, void*, WeslResolveSourceResult*> resolve_source;
[NativeTypeName("WeslResolveSourceFreeFunction")]
public delegate* unmanaged[Cdecl]<WeslResolveSourceResult*, void*, void> resolve_source_free;
[NativeTypeName("WeslResolveModuleFunction")]
public delegate* unmanaged[Cdecl]<sbyte*, void*, WeslResolveModuleResult*> resolve_module;
[NativeTypeName("WeslResolveModuleFreeFunction")]
public delegate* unmanaged[Cdecl]<WeslResolveModuleResult*, void*, void> resolve_module_free;
[NativeTypeName("WeslResolveStringFunction")]
public delegate* unmanaged[Cdecl]<sbyte*, void*, sbyte*> display_name;
[NativeTypeName("WeslResolveFreeStringFunction")]
public delegate* unmanaged[Cdecl]<sbyte*, void*, void> free_display_name;
[NativeTypeName("WeslResolveStringFunction")]
public delegate* unmanaged[Cdecl]<sbyte*, void*, sbyte*> fs_path;
[NativeTypeName("WeslResolveFreeStringFunction")]
public delegate* unmanaged[Cdecl]<sbyte*, void*, void> free_fs_path;
}

View File

@@ -0,0 +1,5 @@
namespace Robust.Client.Interop.RobustNative.Wesl.Gen;
internal partial struct WeslTranslationUnit
{
}

View File

@@ -99,6 +99,7 @@ namespace Robust.Client
deps.Register<IUserInterfaceManager, UserInterfaceManager>();
deps.Register<IUserInterfaceManagerInternal, UserInterfaceManager>();
deps.Register<ILightManager, LightManager>();
deps.Register<IShaderCompiler, ShaderCompiler>();
deps.Register<IDiscordRichPresence, DiscordRichPresence>();
deps.Register<IMidiManager, MidiManager>();
deps.Register<IAuthManager, AuthManager>();

View File

@@ -0,0 +1,46 @@
using System.Collections.Immutable;
using System.Text;
using Robust.Client.Graphics;
using Robust.Shared.Console;
using Robust.Shared.ContentPack;
using Robust.Shared.IoC;
using Robust.Shared.Utility;
namespace Robust.Client.Console.Commands;
internal sealed class CompileShaderCommand : IConsoleCommand
{
[Dependency] private readonly IShaderCompiler _shaderCompiler = null!;
[Dependency] private readonly IResourceManager _resourceManager = null!;
public string Command => "compile_shader";
public string Description => "";
public string Help => "";
public void Execute(IConsoleShell shell, string argStr, string[] args)
{
var path = args[0];
var x = _shaderCompiler.CompileToWgsl(new ResPath(path), ImmutableDictionary<string, bool>.Empty);
if (!x.Success)
{
shell.WriteError("Compilation failed");
return;
}
var codeText = Encoding.UTF8.GetString(x.Code);
shell.WriteLine(codeText);
}
public CompletionResult GetCompletion(IConsoleShell shell, string[] args)
{
if (args.Length == 1)
{
return CompletionResult.FromHintOptions(
CompletionHelper.ContentFilePath(args[0], _resourceManager),
"<path>");
}
return CompletionResult.Empty;
}
}

View File

@@ -1,4 +1,5 @@
using System;
using System.Collections.Immutable;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
@@ -89,12 +90,13 @@ internal partial class Clyde
RhiBufferUsageFlags.Uniform | RhiBufferUsageFlags.CopyDst,
label: "_uniformPassBuffer");
var res = clyde._resourceCache;
var shaderSource = res.ContentFileReadAllText("/Shaders/Internal/default-sprite.wgsl");
var compileResult = clyde._shaderCompiler.CompileToWgsl(
new ResPath("/EngineShaders/Internal/default_sprite.wgsl"),
ImmutableDictionary<string, bool>.Empty);
using var shader = _rhi.CreateShaderModule(new RhiShaderModuleDescriptor(
shaderSource,
"default-sprite.wgsl"
using var shader = _rhi.CreateShaderModule(new RhiShaderModuleDescriptorUtf8(
compileResult.Code,
"default_sprite.wgsl"
));
_group0Layout = _rhi.CreateBindGroupLayout(new RhiBindGroupLayoutDescriptor(

View File

@@ -46,6 +46,7 @@ namespace Robust.Client.Graphics.Clyde
[Dependency] private readonly ClientEntityManager _entityManager = default!;
[Dependency] private readonly IPrototypeManager _proto = default!;
[Dependency] private readonly IReloadManager _reloads = default!;
[Dependency] private readonly IShaderCompiler _shaderCompiler = default!;
private bool _drawingSplash = true;

View File

@@ -0,0 +1,29 @@
using System.Collections.Generic;
using Robust.Shared.Utility;
namespace Robust.Client.Graphics;
public interface IShaderCompiler
{
ShaderCompileResultWgsl CompileToWgsl(ResPath path, IReadOnlyDictionary<string, bool> features);
}
public abstract class ShaderCompileResult
{
public bool Success { get; }
private protected ShaderCompileResult(bool success)
{
Success = success;
}
}
public sealed class ShaderCompileResultWgsl : ShaderCompileResult
{
public byte[] Code { get; }
internal ShaderCompileResultWgsl(byte[] code, bool success) : base(success)
{
Code = code;
}
}

View File

@@ -0,0 +1,212 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading;
using Robust.Client.Interop.RobustNative.Wesl.Gen;
using Robust.Shared.Collections;
using Robust.Shared.ContentPack;
using Robust.Shared.Log;
using Robust.Shared.Utility;
namespace Robust.Client.Graphics;
internal sealed class ShaderCompiler : IShaderCompiler, IDisposable
{
private readonly IResourceManager _resourceManager;
private readonly ISawmill _sawmill;
private bool _disposed;
private GCHandle _resolverGCHandle;
private readonly ReaderWriterLockSlim _rwLock = new();
private ValueList<PackageDefinition> _packages;
public ShaderCompiler(IResourceManager resourceManager, ILogManager logManager)
{
_resolverGCHandle = GCHandle.Alloc(this);
_resourceManager = resourceManager;
_sawmill = logManager.GetSawmill("shader");
RegisterPackage(new ResPath("/EngineShaders"), "Robust");
RegisterPackage(new ResPath("/Shaders"), "Content");
}
public unsafe ShaderCompileResultWgsl CompileToWgsl(ResPath path, IReadOnlyDictionary<string, bool> features)
{
using var _ = _rwLock.ReadGuard();
CheckDisposed();
var resolverOptions = MakeResolverOptions();
var compileOptions = new WeslCompileOptions
{
resolver = &resolverOptions,
condcomp = 1,
imports = 1,
lower = 1,
};
var modulePath = ResPathToModulePath(path);
byte[] modulePathNullTerminated = [.. Encoding.UTF8.GetBytes(modulePath), 0];
WeslResult result;
fixed (byte* pPath = modulePathNullTerminated)
{
result = Wesl.wesl_compile(null, (sbyte*)pPath, &compileOptions, null, null);
}
try
{
if (result.success == 0)
{
var message = Marshal.PtrToStringUTF8((nint)result.error.message);
throw new Exception(message);
return new ShaderCompileResultWgsl([], false);
}
var data = MemoryMarshal.CreateReadOnlySpanFromNullTerminated((byte*)result.data);
return new ShaderCompileResultWgsl(data.ToArray(), true);
}
finally
{
Wesl.wesl_free_result(&result);
}
}
private unsafe WeslResolverOptions MakeResolverOptions()
{
return new WeslResolverOptions
{
resolve_source = &ResolveSource,
resolve_source_free = &ResolveSourceFree,
userdata = (void*)GCHandle.ToIntPtr(_resolverGCHandle),
};
}
private record struct ShaderModule(ResPath ResourcePath, string ModulePath);
public void Dispose()
{
using var _ = _rwLock.WriteGuard();
CheckDisposed();
_disposed = true;
_resolverGCHandle.Free();
}
private void CheckDisposed()
{
if (_disposed)
throw new ObjectDisposedException(nameof(ShaderCompiler));
}
private Stream? ResolveModulePath(string modulePath)
{
var resPath = ModulePathToResPath(modulePath);
if (!resPath.HasValue)
return null;
if (_resourceManager.TryContentFileRead(resPath, out var stream))
return stream;
// Try .wgsl as fallback
resPath = resPath.Value.WithExtension("wgsl");
return _resourceManager.ContentFileReadOrNull(resPath.Value);
}
private ResPath? ModulePathToResPath(string modulePath)
{
var components = modulePath.Split("::");
var packageName = components[0].Split('/')[^1];
foreach (var package in _packages)
{
if (package.Name == packageName)
return package.BasePath / $"{string.Join('/', components[1..])}.wesl";
}
return null;
}
private string ResPathToModulePath(ResPath resPath)
{
if (resPath.Extension is not ("wesl" or "wgsl"))
throw new ArgumentException("Shader path must end in .wesl or .wgsl");
foreach (var package in _packages)
{
if (resPath.TryRelativeTo(package.BasePath, out var relative))
{
var path = string.Join("::", relative.Value.EnumerateSegments());
return $"{package.Name}::" + path[..^5]; // Trim .wesl or .wgsl suffix.
}
}
throw new ArgumentException("Shader path must be inside a proper shader package");
}
private void RegisterPackage(ResPath root, string packageName)
{
_packages.Add(new PackageDefinition
{
Name = packageName, BasePath = root
});
}
[UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
private static unsafe WeslResolveSourceResult* ResolveSource(sbyte* modulePath, void* userdata)
{
var self = (ShaderCompiler)GCHandle.FromIntPtr((nint)userdata).Target!;
var path = Marshal.PtrToStringUTF8((nint)modulePath);
var stream = self.ResolveModulePath(path!);
var result = (WeslResolveSourceResult*)NativeMemory.Alloc((nuint)sizeof(WeslResolveSourceResult));
if (stream == null)
{
*result = new WeslResolveSourceResult
{
success = 0
};
}
else
{
var bytes = stream.CopyToArray();
var nativeSource = (byte*)NativeMemory.Alloc((nuint)bytes.Length + 1);
bytes.CopyTo(new Span<byte>(nativeSource, bytes.Length));
nativeSource[bytes.Length] = 0; // Null terminator
*result = new WeslResolveSourceResult
{
success = 1,
source = (sbyte*)nativeSource,
};
}
return result;
}
[UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
private static unsafe void ResolveSourceFree(WeslResolveSourceResult* result, void* userdata)
{
if (result->success != 0)
NativeMemory.Free(result->source);
NativeMemory.Free(result);
}
private sealed class PackageDefinition
{
public required string Name;
public required ResPath BasePath;
}
}

View File

@@ -140,7 +140,7 @@ namespace Robust.Client.Graphics
case "canvas":
Kind = ShaderKind.Canvas;
_source = IoCManager.Resolve<IResourceCache>().GetResource<ShaderSourceResource>("/Shaders/Internal/default-sprite.swsl");
_source = IoCManager.Resolve<IResourceCache>().GetResource<ShaderSourceResource>("/EngineShaders/Internal/default-sprite.swsl");
break;
default:

1
native/Cargo.lock generated
View File

@@ -1422,6 +1422,7 @@ dependencies = [
name = "wesl-c"
version = "0.2.0"
dependencies = [
"bindgen",
"wesl",
]