Compare commits

..

1 Commits

Author SHA1 Message Date
Ruben Ortlam 037397792a vulkan: split ggml-vulkan.cpp file 2026-06-22 15:50:01 +02:00
42 changed files with 29484 additions and 29723 deletions
+3
View File
@@ -341,6 +341,9 @@ set(GGML_PUBLIC_HEADERS
include/gguf.h)
set_target_properties(ggml PROPERTIES PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}")
#if (GGML_METAL)
# set_target_properties(ggml PROPERTIES RESOURCE "${CMAKE_CURRENT_SOURCE_DIR}/src/ggml-metal.metal")
#endif()
install(TARGETS ggml LIBRARY PUBLIC_HEADER)
install(TARGETS ggml-base LIBRARY)
+51 -119
View File
@@ -24,119 +24,62 @@ if (GGML_METAL_NDEBUG)
endif()
set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/../ggml-common.h")
set(METALLIB_KERNELS_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/kernels/common.h")
set(METALLIB_KERNELS_DEQUANTIZE "${CMAKE_CURRENT_SOURCE_DIR}/kernels/dequantize.h")
set(METALLIB_KERNELS_QUANTIZE "${CMAKE_CURRENT_SOURCE_DIR}/kernels/quantize.h")
set(METALLIB_KERNEL_SOURCES
kernels/fa.metal
kernels/mul_mv.metal
kernels/mul_mm.metal
kernels/quantize.metal
kernels/softmax.metal
kernels/norm.metal
kernels/unary.metal
kernels/binbcast.metal
kernels/reduce.metal
kernels/tri.metal
kernels/ssm.metal
kernels/wkv.metal
kernels/gated_delta_net.metal
kernels/solve_tri.metal
kernels/rope.metal
kernels/conv.metal
kernels/upscale.metal
kernels/argsort.metal
kernels/pool.metal
kernels/misc.metal
)
if (GGML_METAL_EMBED_LIBRARY)
enable_language(ASM)
add_compile_definitions(GGML_METAL_EMBED_LIBRARY)
set(METALLIB_IMPL "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h")
set(METALLIB_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal")
set(METALLIB_IMPL "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h")
file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/autogenerated")
set(METALLIB_EMBED_ASM_FILES "")
foreach(src ${METALLIB_KERNEL_SOURCES})
get_filename_component(kind ${src} NAME_WE)
# symbol names must be valid C identifiers ('-' is not allowed)
string(REPLACE "-" "_" kind_sym ${kind})
# merge ggml-common.h and ggml-metal.metal into a single file
set(METALLIB_EMBED_ASM "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.s")
set(METALLIB_SOURCE_EMBED "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.metal")
set(METALLIB_SOURCE_EMBED_TMP "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp")
set(SRC "${CMAKE_CURRENT_SOURCE_DIR}/kernels/${kind}.metal")
set(EMBED "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed-${kind}.metal")
set(ASM "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed-${kind}.s")
add_custom_command(
OUTPUT "${METALLIB_EMBED_ASM}"
COMMAND echo "Embedding Metal library"
COMMAND sed -e "/__embed_ggml-common.h__/r ${METALLIB_COMMON}" -e "/__embed_ggml-common.h__/d" < "${METALLIB_SOURCE}" > "${METALLIB_SOURCE_EMBED_TMP}"
COMMAND sed -e "/\#include \"ggml-metal-impl.h\"/r ${METALLIB_IMPL}" -e "/\#include \"ggml-metal-impl.h\"/d" < "${METALLIB_SOURCE_EMBED_TMP}" > "${METALLIB_SOURCE_EMBED}"
COMMAND echo ".section __DATA,__ggml_metallib" > "${METALLIB_EMBED_ASM}"
COMMAND echo ".globl _ggml_metallib_start" >> "${METALLIB_EMBED_ASM}"
COMMAND echo "_ggml_metallib_start:" >> "${METALLIB_EMBED_ASM}"
COMMAND echo .incbin "\"${METALLIB_SOURCE_EMBED}\"" >> "${METALLIB_EMBED_ASM}"
COMMAND echo ".globl _ggml_metallib_end" >> "${METALLIB_EMBED_ASM}"
COMMAND echo "_ggml_metallib_end:" >> "${METALLIB_EMBED_ASM}"
DEPENDS ../ggml-common.h ggml-metal.metal ggml-metal-impl.h
COMMENT "Generate assembly for embedded Metal library"
VERBATIM
)
# only prepend headers that this source actually includes
set(HEADERS_FOR_SRC ${METALLIB_KERNELS_COMMON})
file(STRINGS ${SRC} _has_dequantize REGEX "#include \"dequantize\\.h\"")
file(STRINGS ${SRC} _has_quantize REGEX "#include \"quantize\\.h\"")
if(_has_dequantize)
list(APPEND HEADERS_FOR_SRC ${METALLIB_KERNELS_DEQUANTIZE})
endif()
if(_has_quantize)
list(APPEND HEADERS_FOR_SRC ${METALLIB_KERNELS_QUANTIZE})
endif()
add_custom_command(
OUTPUT "${ASM}"
# Step 1: concatenate shared headers + this kernel source
COMMAND cat ${HEADERS_FOR_SRC} ${SRC} > "${EMBED}.tmp1"
# Step 2: remove internal #include and #pragma once
COMMAND sed -e "/\#include \"common.h\"/d" -e "/\#include \"dequantize.h\"/d" -e "/\#include \"quantize.h\"/d" -e "/\#pragma once/d" < "${EMBED}.tmp1" > "${EMBED}.tmp2"
# Step 3: inline ggml-common.h (replacing __embed_ggml-common.h__ sentinel)
COMMAND sed -e "/__embed_ggml-common.h__/r ${METALLIB_COMMON}" -e "/__embed_ggml-common.h__/d" < "${EMBED}.tmp2" > "${EMBED}.tmp3"
# Step 4: inline ggml-metal-impl.h
COMMAND sed -e "/\#include \"ggml-metal-impl.h\"/r ${METALLIB_IMPL}" -e "/\#include \"ggml-metal-impl.h\"/d" < "${EMBED}.tmp3" > "${EMBED}"
# Step 5: emit an asm chunk with kind-specific start/end symbols
# note: '-' is illegal in C symbols, so we use kind_sym; the macOS
# section name is limited to 16 chars so we keep it shared
# across kinds (__ggml_metallib) and only vary the global symbols.
COMMAND echo ".section __DATA,__ggml_metallib" > "${ASM}"
COMMAND echo ".globl _ggml_metallib_${kind_sym}_start" >> "${ASM}"
COMMAND echo "_ggml_metallib_${kind_sym}_start:" >> "${ASM}"
COMMAND echo .incbin "\"${EMBED}\"" >> "${ASM}"
COMMAND echo ".globl _ggml_metallib_${kind_sym}_end" >> "${ASM}"
COMMAND echo "_ggml_metallib_${kind_sym}_end:" >> "${ASM}"
DEPENDS ../ggml-common.h ggml-metal-impl.h
kernels/common.h kernels/dequantize.h kernels/quantize.h
kernels/${kind}.metal
COMMENT "Generate embedded Metal library for ${kind}"
VERBATIM
)
list(APPEND METALLIB_EMBED_ASM_FILES "${ASM}")
endforeach()
target_sources(ggml-metal PRIVATE ${METALLIB_EMBED_ASM_FILES})
target_sources(ggml-metal PRIVATE "${METALLIB_EMBED_ASM}")
else()
# copy header files to bin directory
# copy metal files to bin directory
configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY)
configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)
configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY)
file(MAKE_DIRECTORY "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/kernels")
configure_file(kernels/common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/kernels/common.h COPYONLY)
configure_file(kernels/dequantize.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/kernels/dequantize.h COPYONLY)
configure_file(kernels/quantize.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/kernels/quantize.h COPYONLY)
foreach(src ${METALLIB_KERNEL_SOURCES})
configure_file(${src} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${src} COPYONLY)
endforeach()
if (GGML_METAL_SHADER_DEBUG)
# note: disabling fast math is needed in order to pass tests/test-backend-ops
# custom command to do the following:
# xcrun -sdk macosx metal -fno-fast-math -c ggml-metal.metal -o ggml-metal.air
# xcrun -sdk macosx metallib ggml-metal.air -o default.metallib
#
# note: this is the only way I found to disable fast-math in Metal. it's ugly, but at least it works
# disabling fast math is needed in order to pass tests/test-backend-ops
# note: adding -fno-inline fixes the tests when using MTL_SHADER_VALIDATION=1
# note: unfortunately, we have to call it default.metallib instead of ggml.metallib
# ref: https://github.com/ggml-org/whisper.cpp/issues/1720
# note: adding -g causes segmentation fault during compile
#set(XC_FLAGS -fno-fast-math -fno-inline -g)
set(XC_FLAGS -fno-fast-math -fno-inline)
else()
set(XC_FLAGS -O3)
endif()
# Append macOS metal versioning flags
if (GGML_METAL_MACOSX_VERSION_MIN)
message(STATUS "Adding -mmacosx-version-min=${GGML_METAL_MACOSX_VERSION_MIN} flag to metal compilation")
list (APPEND XC_FLAGS -mmacosx-version-min=${GGML_METAL_MACOSX_VERSION_MIN})
@@ -147,46 +90,35 @@ else()
list (APPEND XC_FLAGS -std=${GGML_METAL_STD})
endif()
# Compile each kernel source to .air, then link into default.metallib
set(AIR_FILES "")
foreach(src ${METALLIB_KERNEL_SOURCES})
get_filename_component(name ${src} NAME_WE)
set(AIR "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${name}.air")
list(APPEND AIR_FILES ${AIR})
add_custom_command(
OUTPUT ${AIR}
COMMAND xcrun -sdk macosx metal ${XC_FLAGS} -I ${CMAKE_RUNTIME_OUTPUT_DIRECTORY} -c ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${src} -o ${AIR}
DEPENDS ${src} kernels/common.h kernels/dequantize.h kernels/quantize.h ${METALLIB_COMMON} ggml-metal-impl.h
COMMENT "Compiling ${src}"
VERBATIM
)
endforeach()
add_custom_command(
OUTPUT ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
COMMAND xcrun -sdk macosx metallib ${AIR_FILES} -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
COMMAND xcrun -sdk macosx metal ${XC_FLAGS} -c ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal -o - |
xcrun -sdk macosx metallib - -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h
COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h
COMMAND rm -rf ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/kernels
DEPENDS ${AIR_FILES}
COMMENT "Linking Metal kernels into default.metallib"
)
COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal
DEPENDS ggml-metal.metal ${METALLIB_COMMON}
COMMENT "Compiling Metal kernels"
)
# FIXME: only add to the ggml-metal target?
add_custom_target(
ggml-metal-lib ALL
DEPENDS ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
)
)
endif() # GGML_METAL_EMBED_LIBRARY
if (NOT GGML_METAL_EMBED_LIBRARY)
install(
DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/kernels/
DESTINATION ${CMAKE_INSTALL_BINDIR}/kernels
FILES_MATCHING PATTERN "*.metal" PATTERN "*.h"
)
FILES src/ggml-metal/ggml-metal.metal
PERMISSIONS
OWNER_READ
OWNER_WRITE
GROUP_READ
WORLD_READ
DESTINATION ${CMAKE_INSTALL_BINDIR})
install(
FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
DESTINATION ${CMAKE_INSTALL_BINDIR}
)
install(
FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
DESTINATION ${CMAKE_INSTALL_BINDIR}
)
endif()
+127 -422
View File
@@ -94,63 +94,8 @@ int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_wi
return pipeline.pipeline->obj.maxTotalThreadsPerThreadgroup;
}
//
// MTLLibrary collection (one library per op-source, compiled separately)
//
// Single source of truth for the per-kind metal libraries. The order here
// defines the enum values and every per-kind table below, so adding a library
// is a one-line change here (plus adding its source to CMakeLists.txt).
// X(suffix, name): name is both the kernels/<name>.metal basename and the
// ggml_metallib_<name>_{start,end} embed-symbol stem.
#define GGML_METAL_LIBS \
X(FA, fa) \
X(MUL_MV, mul_mv) \
X(MUL_MM, mul_mm) \
X(QUANTIZE, quantize) \
X(SOFTMAX, softmax) \
X(NORM, norm) \
X(UNARY, unary) \
X(BINBCAST, binbcast) \
X(REDUCE, reduce) \
X(TRI, tri) \
X(SSM, ssm) \
X(WKV, wkv) \
X(GATED_DELTA_NET, gated_delta_net)\
X(SOLVE_TRI, solve_tri) \
X(ROPE, rope) \
X(CONV, conv) \
X(UPSCALE, upscale) \
X(ARGSORT, argsort) \
X(POOL, pool) \
X(MISC, misc)
enum ggml_metal_lib_kind {
#define X(e, s) GGML_METAL_LIB_##e,
GGML_METAL_LIBS
#undef X
GGML_METAL_LIB_COUNT,
};
static const char * const k_lib_names[GGML_METAL_LIB_COUNT] = {
#define X(e, s) [GGML_METAL_LIB_##e] = #s,
GGML_METAL_LIBS
#undef X
};
struct ggml_metal_library {
// Per-kind compiled libraries. When single_library is true, the whole library
// (e.g. a pre-compiled default.metallib or a from-source build) lives at
// objs[0] and the remaining slots are nil.
id<MTLLibrary> objs[GGML_METAL_LIB_COUNT];
bool single_library; // true: combined library at objs[0]; false: per-kind libs in objs[*]
// Routing table: kernel function name -> objs[] index, populated from each
// compiled library's -[MTLLibrary functionNames]. The actual compiled
// libraries are the single source of truth for which library owns a kernel,
// so adding kernels later requires no manual routing maintenance.
// nil in single_library mode (everything resolves to objs[0]).
NSMutableDictionary<NSString *, NSNumber *> * fn_to_lib;
id<MTLLibrary> obj;
ggml_metal_device_t dev;
ggml_metal_pipelines_t pipelines; // cache of compiled pipelines
@@ -158,376 +103,160 @@ struct ggml_metal_library {
NSLock * lock;
};
// Build the fn_to_lib routing table by querying each compiled library's public
// function names. Call once after all per-kind libraries have been compiled.
static void ggml_metal_library_build_index(ggml_metal_library_t lib) {
@autoreleasepool {
NSMutableDictionary<NSString *, NSNumber *> * index = [[NSMutableDictionary alloc] init];
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
for (NSString * fname in [lib->objs[kind] functionNames]) {
index[fname] = @(kind);
}
}
lib->fn_to_lib = index;
}
}
ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
id<MTLLibrary> library = nil;
id<MTLDevice> device = ggml_metal_device_get_obj(dev);
// Parse a `#include "name"` line. Returns the quoted name in *include_name on
// success. Whitespace-tolerant; ignores `#include <...>` (system headers).
static bool ggml_metal_library_parse_quoted_include(NSString * line, NSString ** include_name) {
NSScanner * scanner = [NSScanner scannerWithString:line];
scanner.charactersToBeSkipped = [NSCharacterSet whitespaceCharacterSet];
// load library
//
// - first check if the library is embedded
// - then check if the library is in the bundle
// - if not found, load the source and compile it
// - if that fails, return NULL
//
// TODO: move to a function
{
const int64_t t_start = ggml_time_us();
if (![scanner scanString:@"#" intoString:NULL] ||
![scanner scanString:@"include" intoString:NULL] ||
![scanner scanString:@"\"" intoString:NULL]) {
return false;
}
NSError * error = nil;
NSString * src = nil;
NSString * name = nil;
if (![scanner scanUpToString:@"\"" intoString:&name]) {
return false;
}
#if GGML_METAL_EMBED_LIBRARY
GGML_LOG_INFO("%s: using embedded metal library\n", __func__);
if (include_name) {
*include_name = name;
}
return true;
}
extern const char ggml_metallib_start[];
extern const char ggml_metallib_end[];
// Recursively inline `#include "name"` directives. System includes (<...>),
// `#if/#else/#endif`, and other preprocessor lines are passed through to the
// Metal compiler unchanged. `#pragma once` is dropped since `seen` already
// guards against double-inclusion.
static bool ggml_metal_library_flatten_file(NSMutableString * dst, NSString * path,
NSArray<NSString *> * search_paths,
NSMutableSet<NSString *> * seen, NSError ** error) {
NSString * key = [path stringByStandardizingPath];
if ([seen containsObject:key]) {
return true;
}
[seen addObject:key];
src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding];
#else
NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:error];
if (!src) {
return false;
}
#ifdef SWIFT_PACKAGE
NSBundle * bundle = SWIFTPM_MODULE_BUNDLE;
#else
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
#endif
NSFileManager * fm = [NSFileManager defaultManager];
for (NSString * line in [src componentsSeparatedByString:@"\n"]) {
NSString * trimmed = [line stringByTrimmingCharactersInSet:[NSCharacterSet whitespaceCharacterSet]];
if ([trimmed isEqualToString:@"#pragma once"]) {
continue;
}
NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
if (path_lib == nil) {
// Try to find the resource in the directory where the current binary located.
NSString * bin_cur = [[NSProcessInfo processInfo] arguments][0];
NSString * bin_dir = [bin_cur stringByDeletingLastPathComponent];
NSString * include_name = nil;
if (ggml_metal_library_parse_quoted_include(line, &include_name)) {
NSString * resolved = nil;
for (NSString * dir in search_paths) {
NSString * candidate = [dir stringByAppendingPathComponent:include_name];
if ([fm isReadableFileAtPath:candidate]) {
resolved = candidate;
break;
NSString * path_lib_default = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]];
if ([[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) {
GGML_LOG_INFO("%s: found '%s'\n", __func__, [path_lib_default UTF8String]);
NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:path_lib_default error:&error];
if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) {
// Optionally, if this is a symlink, try to resolve it.
path_lib_default = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:path_lib_default error:&error];
if (path_lib_default && [path_lib_default length] > 0 && ![[path_lib_default substringToIndex:1] isEqualToString:@"/"]) {
// It is a relative path, adding the binary directory as directory prefix.
path_lib_default = [NSString pathWithComponents:@[bin_dir, path_lib_default]];
}
if (!path_lib_default || ![[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) {
// Link to the resource could not be resolved.
path_lib_default = nil;
} else {
GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [path_lib_default UTF8String]);
}
}
} else {
// The resource couldn't be found in the binary's directory.
path_lib_default = nil;
}
if (!resolved) {
if (error) {
NSString * msg = [NSString stringWithFormat:@"could not resolve include \"%@\" from '%@'", include_name, path];
*error = [NSError errorWithDomain:@"ggml-metal-source-flatten" code:1
userInfo:@{NSLocalizedDescriptionKey: msg}];
}
return false;
}
if (!ggml_metal_library_flatten_file(dst, resolved, search_paths, seen, error)) {
return false;
}
continue;
path_lib = path_lib_default;
}
[dst appendString:line];
[dst appendString:@"\n"];
}
if (path_lib != nil) {
// pre-compiled library found
NSURL * libURL = [NSURL fileURLWithPath:path_lib];
GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);
return true;
}
library = [device newLibraryWithURL:libURL error:&error];
if (error) {
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
return nil;
}
} else {
GGML_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
static NSString * ggml_metal_library_flatten_source(NSString * path_source, NSError ** error) {
// Search paths cover both runtime layout (build/bin/kernels + build/bin)
// and source-tree layout (ggml/src/ggml-metal/kernels + ggml/src/ggml-metal + ggml/src).
NSString * path_kernels = [path_source stringByDeletingLastPathComponent];
NSString * path_base = [path_kernels stringByDeletingLastPathComponent];
NSArray<NSString *> * search_paths = @[
path_kernels,
path_base,
[path_base stringByDeletingLastPathComponent],
];
NSString * path_source;
NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
NSMutableString * src = [[NSMutableString alloc] init];
NSMutableSet<NSString *> * seen = [NSMutableSet set];
GGML_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, path_resource ? [path_resource UTF8String] : "nil");
if (!ggml_metal_library_flatten_file(src, path_source, search_paths, seen, error)) {
[src release];
return nil;
}
return src;
}
// Compile all per-kind libraries in parallel. `source_for_kind` returns the MSL
// source for a kind (the helper takes ownership and releases it), or nil with
// *err set on failure. On success the objs[] slots are populated and the routing
// index is built; on any failure every error is logged and false is returned
// (the caller is responsible for freeing `res`).
static bool ggml_metal_library_compile_all(
ggml_metal_library_t res,
id<MTLDevice> device,
NSDictionary * prep,
NSString * (^source_for_kind)(int kind, NSError ** err),
const char * origin) {
const int64_t t_start = ggml_time_us();
int64_t * t_per_lib = calloc(GGML_METAL_LIB_COUNT, sizeof(int64_t));
NSError ** err_per_lib = calloc(GGML_METAL_LIB_COUNT, sizeof(NSError *));
__block atomic_bool any_failure = false;
dispatch_group_t group = dispatch_group_create();
dispatch_queue_t queue = dispatch_get_global_queue(QOS_CLASS_USER_INITIATED, 0);
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
dispatch_group_async(group, queue, ^{
const int64_t t0 = ggml_time_us();
NSError * error = nil;
NSString * src = source_for_kind(kind, &error);
if (!src) {
err_per_lib[kind] = [error retain];
atomic_store(&any_failure, true);
return;
if (path_resource) {
path_source = [path_resource stringByAppendingPathComponent:@"ggml-metal.metal"];
} else {
path_source = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
}
id<MTLLibrary> lib = nil;
if (path_source == nil) {
GGML_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
path_source = @"ggml-metal.metal";
}
GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_source UTF8String]);
src = [NSString stringWithContentsOfFile:path_source encoding:NSUTF8StringEncoding error:&error];
if (error) {
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
return nil;
}
}
#endif
if (!library) {
@autoreleasepool {
// dictionary of preprocessor macros
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
if (ggml_metal_device_get_props(dev)->has_bfloat) {
[prep setObject:@"1" forKey:@"GGML_METAL_HAS_BF16"];
}
if (ggml_metal_device_get_props(dev)->has_tensor) {
[prep setObject:@"1" forKey:@"GGML_METAL_HAS_TENSOR"];
}
#if GGML_METAL_EMBED_LIBRARY
[prep setObject:@"1" forKey:@"GGML_METAL_EMBED_LIBRARY"];
#endif
MTLCompileOptions * options = [MTLCompileOptions new];
options.preprocessorMacros = prep;
lib = [device newLibraryWithSource:src options:options error:&error];
//[options setFastMathEnabled:false];
[options release];
// retain the error before the autorelease pool drains it
if (!lib) {
err_per_lib[kind] = [error retain];
library = [device newLibraryWithSource:src options:options error:&error];
if (error) {
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
return nil;
}
}
[src release];
t_per_lib[kind] = ggml_time_us() - t0;
if (!lib) {
atomic_store(&any_failure, true);
return;
}
res->objs[kind] = lib;
});
}
dispatch_group_wait(group, DISPATCH_TIME_FOREVER);
dispatch_release(group);
const bool ok = !atomic_load(&any_failure);
if (ok) {
const int64_t t_total = ggml_time_us() - t_start;
int64_t t_max = 0;
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
GGML_LOG_DEBUG("%s: compiled '%s' library in %.3f sec\n",
__func__, k_lib_names[kind], t_per_lib[kind] / 1e6);
if (t_per_lib[kind] > t_max) t_max = t_per_lib[kind];
}
GGML_LOG_INFO("%s: loaded %d libraries from %s in %.3f sec (max single = %.3f sec)\n",
__func__, GGML_METAL_LIB_COUNT, origin, t_total / 1e6, t_max / 1e6);
ggml_metal_library_build_index(res);
} else {
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
if (err_per_lib[kind]) {
GGML_LOG_ERROR("%s: failed to build '%s' library: %s\n", __func__,
k_lib_names[kind], [[err_per_lib[kind] description] UTF8String]);
[err_per_lib[kind] release];
#if !__has_feature(objc_arc)
[options release];
#endif
}
}
#if GGML_METAL_EMBED_LIBRARY
[src release];
#endif // GGML_METAL_EMBED_LIBRARY
GGML_LOG_INFO("%s: loaded in %.3f sec\n", __func__, (ggml_time_us() - t_start) / 1e6);
}
free(err_per_lib);
free(t_per_lib);
return ok;
}
ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
id<MTLDevice> device = ggml_metal_device_get_obj(dev);
ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library));
res->obj = library;
res->dev = dev;
res->pipelines = ggml_metal_pipelines_init();
res->lock = [NSLock new];
// shared MTLCompileOptions preprocessor macros (matches the build-time defines)
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
if (ggml_metal_device_get_props(dev)->has_bfloat) {
[prep setObject:@"1" forKey:@"GGML_METAL_HAS_BF16"];
}
if (ggml_metal_device_get_props(dev)->has_tensor) {
[prep setObject:@"1" forKey:@"GGML_METAL_HAS_TENSOR"];
}
#if GGML_METAL_EMBED_LIBRARY
[prep setObject:@"1" forKey:@"GGML_METAL_EMBED_LIBRARY"];
#endif
#if GGML_METAL_EMBED_LIBRARY
GGML_LOG_INFO("%s: using embedded metal library\n", __func__);
// start/end symbols emitted by CMake (see CMakeLists.txt), one pair per kind
#define X(e, s) extern const char ggml_metallib_##s##_start[]; extern const char ggml_metallib_##s##_end[];
GGML_METAL_LIBS
#undef X
static const char * const lib_start[GGML_METAL_LIB_COUNT] = {
#define X(e, s) [GGML_METAL_LIB_##e] = ggml_metallib_##s##_start,
GGML_METAL_LIBS
#undef X
};
static const char * const lib_end[GGML_METAL_LIB_COUNT] = {
#define X(e, s) [GGML_METAL_LIB_##e] = ggml_metallib_##s##_end,
GGML_METAL_LIBS
#undef X
};
const bool ok = ggml_metal_library_compile_all(res, device, prep,
^NSString * (int kind, NSError ** err) {
(void) err;
return [[NSString alloc] initWithBytes:lib_start[kind]
length:(lib_end[kind] - lib_start[kind])
encoding:NSUTF8StringEncoding];
}, "embedded data");
if (!ok) {
ggml_metal_library_free(res);
return NULL;
}
return res;
#else
#ifdef SWIFT_PACKAGE
NSBundle * bundle = SWIFTPM_MODULE_BUNDLE;
#else
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
#endif
const int64_t t_start = ggml_time_us();
NSError * error = nil;
NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
if (path_lib == nil) {
// Try to find the resource in the directory where the current binary located.
NSString * bin_cur = [[NSProcessInfo processInfo] arguments][0];
NSString * bin_dir = [bin_cur stringByDeletingLastPathComponent];
NSString * path_lib_default = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]];
if ([[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) {
GGML_LOG_INFO("%s: found '%s'\n", __func__, [path_lib_default UTF8String]);
NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:path_lib_default error:&error];
if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) {
// Optionally, if this is a symlink, try to resolve it.
path_lib_default = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:path_lib_default error:&error];
if (path_lib_default && [path_lib_default length] > 0 && ![[path_lib_default substringToIndex:1] isEqualToString:@"/"]) {
// It is a relative path, adding the binary directory as directory prefix.
path_lib_default = [NSString pathWithComponents:@[bin_dir, path_lib_default]];
}
if (!path_lib_default || ![[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) {
// Link to the resource could not be resolved.
path_lib_default = nil;
} else {
GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [path_lib_default UTF8String]);
}
}
} else {
// The resource couldn't be found in the binary's directory.
path_lib_default = nil;
}
path_lib = path_lib_default;
}
if (path_lib != nil) {
// pre-compiled library found: a single combined default.metallib
NSURL * libURL = [NSURL fileURLWithPath:path_lib];
GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);
res->objs[0] = [device newLibraryWithURL:libURL error:&error];
res->single_library = true;
if (!res->objs[0]) {
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
ggml_metal_library_free(res);
return NULL;
}
GGML_LOG_INFO("%s: loaded in %.3f sec\n", __func__, (ggml_time_us() - t_start) / 1e6);
return res;
}
// no pre-compiled metallib: fall back to compiling each kernel source separately
GGML_LOG_INFO("%s: default.metallib not found, loading kernel sources\n", __func__);
NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
if (path_resource) {
GGML_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, [path_resource UTF8String]);
}
// resolve each kind's source path up front (file lookup/logging stays on the calling thread)
NSString ** path_per_kind = calloc(GGML_METAL_LIB_COUNT, sizeof(NSString *));
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
NSString * rel = [NSString stringWithFormat:@"kernels/%s.metal", k_lib_names[kind]];
NSString * path_source = nil;
if (path_resource) {
path_source = [path_resource stringByAppendingPathComponent:rel];
} else {
NSString * stem = [NSString stringWithFormat:@"kernels/%s", k_lib_names[kind]];
path_source = [bundle pathForResource:stem ofType:@"metal"];
}
if (path_source == nil || ![[NSFileManager defaultManager] isReadableFileAtPath:path_source]) {
GGML_LOG_WARN("%s: could not locate %s in bundle, falling back to cwd\n", __func__, [rel UTF8String]);
path_source = rel;
}
GGML_LOG_DEBUG("%s: loading '%s'\n", __func__, [path_source UTF8String]);
path_per_kind[kind] = [path_source retain];
}
const bool ok = ggml_metal_library_compile_all(res, device, prep,
^NSString * (int kind, NSError ** err) {
return ggml_metal_library_flatten_source(path_per_kind[kind], err);
}, "source");
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
[path_per_kind[kind] release];
}
free(path_per_kind);
if (!ok) {
ggml_metal_library_free(res);
return NULL;
}
return res;
#endif
}
ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev, const char * source, bool verbose) {
@@ -589,11 +318,10 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev
return NULL;
}
res->objs[0] = library;
res->single_library = true;
res->dev = dev;
res->pipelines = ggml_metal_pipelines_init();
res->lock = [NSLock new];
res->obj = library;
res->dev = dev;
res->pipelines = ggml_metal_pipelines_init();
res->lock = [NSLock new];
return res;
}
@@ -603,14 +331,8 @@ void ggml_metal_library_free(ggml_metal_library_t lib) {
return;
}
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
if (lib->objs[kind]) {
[lib->objs[kind] release];
}
}
if (lib->fn_to_lib) {
[lib->fn_to_lib release];
if (lib->obj) {
[lib->obj release];
}
ggml_metal_pipelines_free(lib->pipelines);
@@ -671,28 +393,11 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_
GGML_LOG_DEBUG("%s: compiling pipeline: base = '%s', name = '%s'\n", __func__, base, name);
// route to the library that actually defines this kernel; fn_to_lib is
// built from -[MTLLibrary functionNames] so it's always in sync
int lib_idx = 0;
if (!lib->single_library) {
NSNumber * idx = lib->fn_to_lib[base_func];
if (!idx) {
[lib->lock unlock];
GGML_LOG_ERROR("%s: kernel not found in any metal library: base = '%s', name = '%s'\n", __func__, base, name);
return res;
}
lib_idx = [idx intValue];
}
id<MTLLibrary> mtl_lib = lib->objs[lib_idx];
id<MTLFunction> mtl_function;
if (!cv) {
mtl_function = [mtl_lib newFunctionWithName:base_func];
mtl_function = [lib->obj newFunctionWithName:base_func];
} else {
mtl_function = [mtl_lib newFunctionWithName:base_func constantValues:cv->obj error:&error];
mtl_function = [lib->obj newFunctionWithName:base_func constantValues:cv->obj error:&error];
}
if (!mtl_function) {
[lib->lock unlock];
File diff suppressed because it is too large Load Diff
-232
View File
@@ -1,232 +0,0 @@
#include "common.h"
// bitonic sort implementation following the CUDA kernels as reference
typedef void (argsort_t)(
constant ggml_metal_kargs_argsort & args,
device const char * src0,
device int32_t * dst,
threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]);
template<ggml_sort_order order>
kernel void kernel_argsort_f32_i32(
constant ggml_metal_kargs_argsort & args,
device const char * src0,
device int32_t * dst,
threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
// bitonic sort
const int col = tpitg[0];
const int ib = tgpig[0] / args.ne01;
const int i00 = ib*ntg.x;
const int i01 = tgpig[0] % args.ne01;
const int i02 = tgpig[1];
const int i03 = tgpig[2];
device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
// initialize indices
shmem_i32[col] = i00 + col;
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int k = 2; k <= ntg.x; k *= 2) {
for (int j = k / 2; j > 0; j /= 2) {
int ixj = col ^ j;
if (ixj > col) {
if ((col & k) == 0) {
if (shmem_i32[col] >= args.ne00 ||
(shmem_i32[ixj] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]] :
src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]]))
) {
SWAP(shmem_i32[col], shmem_i32[ixj]);
}
} else {
if (shmem_i32[ixj] >= args.ne00 ||
(shmem_i32[col] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]] :
src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]]))
) {
SWAP(shmem_i32[col], shmem_i32[ixj]);
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
}
const int64_t i0 = ib*args.top_k;
// copy the result to dst without the padding
if (i0 + col < args.ne0 && col < args.top_k) {
dst += i0 + args.ne0*i01 + args.ne0*args.ne1*i02 + args.ne0*args.ne1*args.ne2*i03;
dst[col] = shmem_i32[col];
}
}
template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
typedef void (argsort_merge_t)(
constant ggml_metal_kargs_argsort_merge & args,
device const char * src0,
device const int32_t * tmp,
device int32_t * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]);
template<ggml_sort_order order>
kernel void kernel_argsort_merge_f32_i32(
constant ggml_metal_kargs_argsort_merge & args,
device const char * src0,
device const int32_t * tmp,
device int32_t * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int im = tgpig[0] / args.ne01;
const int i01 = tgpig[0] % args.ne01;
const int i02 = tgpig[1];
const int i03 = tgpig[2];
const int start = im * (2 * args.len);
const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start)));
const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len)));
const int total = len0 + len1;
device const int32_t * tmp0 = tmp + start
+ i01*args.ne0
+ i02*args.ne0*args.ne01
+ i03*args.ne0*args.ne01*args.ne02;
device const int32_t * tmp1 = tmp0 + args.len;
dst += start
+ i01*args.top_k
+ i02*args.top_k*args.ne01
+ i03*args.top_k*args.ne01*args.ne02;
device const float * src0_row = (device const float *)(src0
+ args.nb01*i01
+ args.nb02*i02
+ args.nb03*i03);
if (total == 0) {
return;
}
const int chunk = (total + ntg.x - 1) / ntg.x;
const int k0 = tpitg.x * chunk;
const int k1 = MIN(MIN(k0 + chunk, total), args.top_k);
if (k0 >= args.top_k) {
return;
}
if (k0 >= total) {
return;
}
int low = k0 > len1 ? k0 - len1 : 0;
int high = MIN(k0, len0);
// binary-search partition (i, j) such that i + j = k
while (low < high) {
const int mid = (low + high) >> 1;
const int32_t idx0 = tmp0[mid];
const int32_t idx1 = tmp1[k0 - mid - 1];
const float val0 = src0_row[idx0];
const float val1 = src0_row[idx1];
bool take_left;
if (order == GGML_SORT_ORDER_ASC) {
take_left = (val0 <= val1);
} else {
take_left = (val0 >= val1);
}
if (take_left) {
low = mid + 1;
} else {
high = mid;
}
}
int i = low;
int j = k0 - i;
// keep the merge fronts into registers
int32_t idx0 = 0;
float val0 = 0.0f;
if (i < len0) {
idx0 = tmp0[i];
val0 = src0_row[idx0];
}
int32_t idx1 = 0;
float val1 = 0.0f;
if (j < len1) {
idx1 = tmp1[j];
val1 = src0_row[idx1];
}
for (int k = k0; k < k1; ++k) {
int32_t out_idx;
if (i >= len0) {
while (k < k1) {
dst[k++] = tmp1[j++];
}
break;
} else if (j >= len1) {
while (k < k1) {
dst[k++] = tmp0[i++];
}
break;
} else {
bool take_left;
if (order == GGML_SORT_ORDER_ASC) {
take_left = (val0 <= val1);
} else {
take_left = (val0 >= val1);
}
if (take_left) {
out_idx = idx0;
++i;
if (i < len0) {
idx0 = tmp0[i];
val0 = src0_row[idx0];
}
} else {
out_idx = idx1;
++j;
if (j < len1) {
idx1 = tmp1[j];
val1 = src0_row[idx1];
}
}
}
dst[k] = out_idx;
}
}
template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;
template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;
-226
View File
@@ -1,226 +0,0 @@
#include "common.h"
// OP: 0 - add, 1 - sub, 2 - mul, 3 - div
constant short FC_bin_op [[function_constant(FC_BIN + 0)]];
constant short FC_bin_f [[function_constant(FC_BIN + 1)]];
constant bool FC_bin_rb [[function_constant(FC_BIN + 2)]];
constant bool FC_bin_cb [[function_constant(FC_BIN + 3)]];
template <typename T0, typename T1, typename T>
kernel void kernel_bin_fuse_impl(
constant ggml_metal_kargs_bin & args,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
#define FC_OP FC_bin_op
#define FC_F FC_bin_f
#define FC_RB FC_bin_rb
#define FC_CB FC_bin_cb
if (FC_RB) {
// row broadcast
const uint i0 = tgpig.y*args.ne00 + tgpig.x;
const uint i1 = FC_CB ? tgpig.x%args.ne10 : tgpig.x;
device const T0 * src0_row = (device const T0 *) (src0);
device T * dst_row = (device T *) (dst);
if (FC_F == 1) {
device const T1 * src1_row = (device const T1 *) (src1 + args.o1[0]);
if (FC_OP == 0) {
dst_row[i0] = src0_row[i0] + src1_row[i1];
}
if (FC_OP == 1) {
dst_row[i0] = src0_row[i0] - src1_row[i1];
}
if (FC_OP == 2) {
dst_row[i0] = src0_row[i0] * src1_row[i1];
}
if (FC_OP == 3) {
dst_row[i0] = src0_row[i0] / src1_row[i1];
}
} else {
T0 res = src0_row[i0];
if (FC_OP == 0) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res += ((device const T1 *) (src1 + args.o1[j]))[i1];
}
}
if (FC_OP == 1) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res -= ((device const T1 *) (src1 + args.o1[j]))[i1];
}
}
if (FC_OP == 2) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res *= ((device const T1 *) (src1 + args.o1[j]))[i1];
}
}
if (FC_OP == 3) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res /= ((device const T1 *) (src1 + args.o1[j]))[i1];
}
}
dst_row[i0] = res;
}
} else {
const int i03 = tgpig.z;
const int i02 = tgpig.y;
const int i01 = tgpig.x;
if (i01 >= args.ne01) {
return;
}
const int i13 = i03%args.ne13;
const int i12 = i02%args.ne12;
const int i11 = i01%args.ne11;
device const T0 * src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
device T * dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs);
if (FC_F == 1) {
device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const int i10 = FC_CB ? i0%args.ne10 : i0;
if (FC_OP == 0) {
dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10];
}
if (FC_OP == 1) {
dst_ptr[i0] = src0_ptr[i0] - src1_ptr[i10];
}
if (FC_OP == 2) {
dst_ptr[i0] = src0_ptr[i0] * src1_ptr[i10];
}
if (FC_OP == 3) {
dst_ptr[i0] = src0_ptr[i0] / src1_ptr[i10];
}
}
} else {
device const T1 * src1_ptr[8];
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
src1_ptr[j] = (device const T1 *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
}
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const int i10 = FC_CB ? i0%args.ne10 : i0;
T res = src0_ptr[i0];
if (FC_OP == 0) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res += src1_ptr[j][i10];
}
}
if (FC_OP == 1) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res -= src1_ptr[j][i10];
}
}
if (FC_OP == 2) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res *= src1_ptr[j][i10];
}
}
if (FC_OP == 3) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res /= src1_ptr[j][i10];
}
}
dst_ptr[i0] = res;
}
}
}
#undef FC_OP
#undef FC_F
#undef FC_RB
#undef FC_CB
}
typedef decltype(kernel_bin_fuse_impl<float, float, float>) kernel_bin_fuse_t;
template [[host_name("kernel_bin_fuse_f32_f32_f32")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float, float, float>;
template [[host_name("kernel_bin_fuse_f32_f32_f32_4")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float4, float4, float4>;
kernel void kernel_add_id(
constant ggml_metal_kargs_add_id & args,
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i1 = tgpig.x;
const int i2 = tgpig.y;
const int i11 = *((device const int32_t *) (src2 + i1*sizeof(int32_t) + i2*args.nb21));
const size_t nb1 = args.ne0 * sizeof(float);
const size_t nb2 = args.ne1 * nb1;
device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2);
device const float * src0_row = (device const float *)((device char *)src0 + i1*args.nb01 + i2*args.nb02);
device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11);
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
dst_row[i0] = src0_row[i0] + src1_row[i0];
}
}
template<typename T>
kernel void kernel_repeat(
constant ggml_metal_kargs_repeat & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i3 = tgpig.z;
const int i2 = tgpig.y;
const int i1 = tgpig.x;
const int i03 = i3%args.ne03;
const int i02 = i2%args.ne02;
const int i01 = i1%args.ne01;
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
device char * dst_ptr = dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1;
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const int i00 = i0%args.ne00;
*((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00));
}
}
typedef decltype(kernel_repeat<float>) kernel_repeat_t;
template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_repeat_bf16")]] kernel kernel_repeat_t kernel_repeat<bfloat>;
#endif
template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
-126
View File
@@ -1,126 +0,0 @@
#pragma once
#include "ggml-metal-impl.h"
#include <metal_stdlib>
#ifdef GGML_METAL_HAS_TENSOR
#include <metal_tensor>
#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>
#endif
using namespace metal;
#define MAX(x, y) ((x) > (y) ? (x) : (y))
#define MIN(x, y) ((x) < (y) ? (x) : (y))
#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
#define PAD2(x, n) (((x) + (n) - 1) & ~((n) - 1))
#define FOR_UNROLL(x) _Pragma("clang loop unroll(full)") for (x)
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
//
// cmd:
// .../usr/bin/metal -dM -E -c ggml/src/ggml-metal/kernels/<src>.metal
// .../usr/bin/metal -dM -E -c -target air64-apple-ios14.0 ggml/src/ggml-metal/kernels/<src>.metal
//
#if __METAL_VERSION__ < 310 && defined(GGML_METAL_HAS_BF16)
#undef GGML_METAL_HAS_BF16
#endif
#if defined(GGML_METAL_HAS_BF16)
typedef matrix<bfloat, 4, 4> bfloat4x4;
typedef matrix<bfloat, 2, 4> bfloat2x4;
#endif
constexpr constant static float kvalues_iq4nl_f[16] = {
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
};
constexpr constant static float kvalues_mxfp4_f[16] = {
0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f
};
static inline int best_index_int8(int n, constant float * val, float x) {
if (x <= val[0]) return 0;
if (x >= val[n-1]) return n-1;
int ml = 0, mu = n-1;
while (mu-ml > 1) {
int mav = (ml+mu)/2;
if (x < val[mav]) mu = mav; else ml = mav;
}
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
}
static inline float e8m0_to_fp32(uint8_t x) {
uint32_t bits;
if (x == 0) {
bits = 0x00400000;
} else {
bits = (uint32_t) x << 23;
}
return as_type<float>(bits);
}
static inline float dot(float x, float y) {
return x*y;
}
static inline float sum(float x) {
return x;
}
static inline float sum(float4 x) {
return x[0] + x[1] + x[2] + x[3];
}
enum ggml_sort_order {
GGML_SORT_ORDER_ASC,
GGML_SORT_ORDER_DESC,
};
constant float GELU_COEF_A = 0.044715f;
constant float GELU_QUICK_COEF = -1.702f;
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
constant float SQRT_2_INV = 0.70710678118654752440084436210484f;
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
// ref: https://www.johndcook.com/blog/python_erf/
constant float p_erf = 0.3275911f;
constant float a1_erf = 0.254829592f;
constant float a2_erf = -0.284496736f;
constant float a3_erf = 1.421413741f;
constant float a4_erf = -1.453152027f;
constant float a5_erf = 1.061405429f;
template<typename T>
inline T erf_approx(T x) {
T sign_x = sign(x);
x = fabs(x);
T t = 1.0f / (1.0f + p_erf * x);
T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
return sign_x * y;
}
template<typename T> T elu_approx(T x);
template<> inline float elu_approx<float>(float x) {
return (x > 0.f) ? x : (exp(x) - 1);
}
template<> inline float4 elu_approx<float4>(float4 x) {
float4 res;
res[0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
res[1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
res[2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
res[3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
return res;
}
-485
View File
@@ -1,485 +0,0 @@
#include "common.h"
typedef void (im2col_t)(
constant ggml_metal_kargs_im2col & args,
device const float * x,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
template <typename T>
kernel void kernel_im2col(
constant ggml_metal_kargs_im2col & args,
device const float * x,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
// const int64_t IC = tgpg[0];
const int64_t OH = tgpg[1];
const int64_t OW = tgpg[2];
const int64_t KH = ntg[1];
const int64_t KW = ntg[2];
int64_t in = tpitg[0];
const int64_t ikh = tpitg[1];
const int64_t ikw = tpitg[2];
const int64_t iic = tgpig[0];
const int64_t ioh = tgpig[1];
const int64_t iow = tgpig[2];
const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0;
const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1;
int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw);
device T * pdst = (device T *) (dst);
if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
while (in < args.N) {
pdst[offset_dst] = 0.0f;
offset_dst += ntg[0]*args.CHW*OH*OW;
in += ntg[0];
}
} else {
int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw;
while (in < args.N) {
pdst[offset_dst] = x[offset_src];
offset_dst += ntg[0]*args.CHW*OH*OW;
offset_src += ntg[0]*args.ofs0;
in += ntg[0];
}
}
}
template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
// TODO: optimize
typedef void (im2col_ext_t)(
constant ggml_metal_kargs_im2col & args,
device const float * x,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
template <typename T>
kernel void kernel_im2col_ext(
constant ggml_metal_kargs_im2col & args,
device const float * x,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
const int64_t KHW = (int64_t)args.KHW;
const int64_t d = tgpig[0] / args.CHW;
const int64_t chw = tgpig[0] % args.CHW;
const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
const int64_t HW = tgpig[0] % KHW;
const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
if (tpitg_0 >= args.N) {
return;
}
const int64_t tpitg_1 = HW / args.KW;
const int64_t tpitg_2 = HW % args.KW;
const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
const int64_t offset_dst =
(tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
(tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
device T * pdst = (device T *) (dst);
if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
pdst[offset_dst] = 0.0f;
} else {
const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
}
}
template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
template <typename TK>
kernel void kernel_conv_2d(
constant ggml_metal_kargs_conv_2d & args,
device const char * weights,
device const char * src,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const uint threads_per_tg = ntg.x * ntg.y * ntg.z;
const uint tg_index = (tgpig.z * tgpg.y + tgpig.y) * tgpg.x + tgpig.x;
const uint local_thread = tpitg.z * (ntg.x * ntg.y) + tpitg.y * ntg.x + tpitg.x;
const uint thread_index = tg_index * threads_per_tg + local_thread;
const uint64_t total_threads = (uint64_t) threads_per_tg * tgpg.x * tgpg.y * tgpg.z;
const uint64_t total_outputs = (uint64_t) args.N * args.OC * args.OH * args.OW;
for (uint64_t index = thread_index; index < total_outputs; index += total_threads) {
uint64_t tmp = index;
const int32_t ow = tmp % args.OW; tmp /= args.OW;
const int32_t oh = tmp % args.OH; tmp /= args.OH;
const int32_t oc = tmp % args.OC; tmp /= args.OC;
const int32_t n = tmp;
float acc = 0.0f;
const int32_t base_x = ow*args.s0 - args.p0;
const int32_t base_y = oh*args.s1 - args.p1;
int32_t ky_start = 0;
if (base_y < 0) {
ky_start = (-base_y + args.d1 - 1)/args.d1;
}
int32_t ky_end = args.KH;
const int32_t y_max = args.IH - 1 - base_y;
if (y_max < 0) {
ky_end = ky_start;
} else if (base_y + (args.KH - 1)*args.d1 >= args.IH) {
ky_end = min(ky_end, y_max/args.d1 + 1);
}
int32_t kx_start = 0;
if (base_x < 0) {
kx_start = (-base_x + args.d0 - 1)/args.d0;
}
int32_t kx_end = args.KW;
const int32_t x_max = args.IW - 1 - base_x;
if (x_max < 0) {
kx_end = kx_start;
} else if (base_x + (args.KW - 1)*args.d0 >= args.IW) {
kx_end = min(kx_end, x_max/args.d0 + 1);
}
if (ky_start < ky_end && kx_start < kx_end) {
const uint64_t src_base_n = (uint64_t) n * args.nb13;
const uint64_t w_base_oc = (uint64_t) oc * args.nb03;
for (int32_t ic = 0; ic < args.IC; ++ic) {
const uint64_t src_base_nc = src_base_n + (uint64_t) ic * args.nb12;
const uint64_t w_base_ocic = w_base_oc + (uint64_t) ic * args.nb02;
for (int32_t ky = ky_start; ky < ky_end; ++ky) {
const int32_t iy = base_y + ky*args.d1;
const uint64_t src_base_row = src_base_nc + (uint64_t) iy * args.nb11;
const uint64_t w_base_row = w_base_ocic + (uint64_t) ky * args.nb01;
for (int32_t kx = kx_start; kx < kx_end; ++kx) {
const int32_t ix = base_x + kx*args.d0;
const uint64_t src_offs = src_base_row + (uint64_t) ix * args.nb10;
const uint64_t w_offs = w_base_row + (uint64_t) kx * args.nb00;
const float x = *(device const float *)(src + src_offs);
const float w = (float) (*(device const TK *)(weights + w_offs));
acc += x * w;
}
}
}
}
const uint64_t dst_offs =
(uint64_t) n * args.nb3 +
(uint64_t) oc * args.nb2 +
(uint64_t) oh * args.nb1 +
(uint64_t) ow * args.nb0;
*(device float *)(dst + dst_offs) = acc;
}
}
template [[host_name("kernel_conv_2d_f32_f32")]]
kernel void kernel_conv_2d<float>(
constant ggml_metal_kargs_conv_2d & args,
device const char * weights,
device const char * src,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
template [[host_name("kernel_conv_2d_f16_f32")]]
kernel void kernel_conv_2d<half>(
constant ggml_metal_kargs_conv_2d & args,
device const char * weights,
device const char * src,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
typedef void (conv_transpose_1d_t)(
constant ggml_metal_kargs_conv_transpose_1d & args,
device const float * src0,
device const float * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]);
template <typename T>
kernel void kernel_conv_transpose_1d(
constant ggml_metal_kargs_conv_transpose_1d & args,
device const T * src0,
device const float * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]) {
// For output position j on the time axis, only input positions
// i such that i*s0 <= j < i*s0 + K
// contribute -- i.e. i in [ceil((j - K + 1)/s0), floor(j/s0)]
// intersected with [0, IL-1]. That's at most ceil(K/s0) values
// (typically 2 for stride==K/2 transposed convs).
const int32_t j = tgpig[0];
const int32_t s0 = args.s0;
const int32_t K = args.K;
const int32_t IL = args.IL;
int32_t i_min;
{
int32_t a = j - K + 1;
i_min = a <= 0 ? 0 : (a + s0 - 1) / s0; // ceil(a/s0) for a>0
}
int32_t i_max = j / s0;
if (i_max > IL - 1) i_max = IL - 1;
float v = 0.0f;
if (i_min <= i_max) {
for (int64_t c = 0; c < args.IC; c++) {
const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1];
const int32_t input_offset = c * IL;
for (int32_t i = i_min; i <= i_max; i++) {
v += float(src0[kernel_offset + j - i * s0]) * src1[input_offset + i];
}
}
}
device float * dst_ptr = (device float *) (dst + tgpig[0] * args.nb0 + tgpig[1] * args.nb1);
dst_ptr[0] = v;
}
template [[host_name("kernel_conv_transpose_1d_f32_f32")]]
kernel void kernel_conv_transpose_1d<float>(
constant ggml_metal_kargs_conv_transpose_1d & args,
device const float * src0,
device const float * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]);
template [[host_name("kernel_conv_transpose_1d_f16_f32")]]
kernel void kernel_conv_transpose_1d<half>(
constant ggml_metal_kargs_conv_transpose_1d & args,
device const half * src0,
device const float * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]);
typedef void (conv_transpose_2d_t)(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const float * src0,
device const float * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]);
template <typename T>
kernel void kernel_conv_transpose_2d(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const T * src0,
device const float * src1,
device char * dst,
threadgroup float * shared_sum [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t out_x = tgpig[0];
const int64_t out_y = tgpig[1];
const int64_t out_c = tgpig[2];
const int64_t kw = tpitg[0];
const int64_t kh = tpitg[1];
float v = 0.0f;
for (int64_t in_c = 0; in_c < args.IC; in_c++) {
int64_t in_y = out_y - kh;
if (in_y < 0 || in_y % args.s0) continue;
in_y /= args.s0;
if (in_y >= args.IH) continue;
int64_t in_x = out_x - kw;
if (in_x < 0 || in_x % args.s0) continue;
in_x /= args.s0;
if (in_x >= args.IW) continue;
const int64_t input_idx = (args.IW * args.IH) * in_c + (args.IW) * in_y + in_x;
const int64_t kernel_idx = (args.KH * args.KW * args.OC) * in_c + (args.KH * args.KW) * out_c + (args.KW) * kh + kw;
v += (float)src0[kernel_idx] * src1[input_idx];
}
const uint tid = tpitg.y * ntg.x + tpitg.x;
shared_sum[tid] = v;
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tid == 0) {
float total = 0.0f;
const uint num_threads = ntg.x * ntg.y;
for (uint i = 0; i < num_threads; i++) {
total += shared_sum[i];
}
device float * dst_ptr = (device float *) (dst + out_x*args.nb0 + out_y * args.nb1 + out_c*args.nb2);
dst_ptr[0] = total;
}
}
template [[host_name("kernel_conv_transpose_2d_f32_f32")]]
kernel void kernel_conv_transpose_2d<float>(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const float * src0,
device const float * src1,
device char * dst,
threadgroup float * shared_sum [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
template [[host_name("kernel_conv_transpose_2d_f16_f32")]]
kernel void kernel_conv_transpose_2d<half>(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const half * src0,
device const float * src1,
device char * dst,
threadgroup float * shared_sum [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
template <typename T>
kernel void kernel_conv_3d(
constant ggml_metal_kargs_conv_3d & args,
device const char * src0, // Weights [IC * OC, KD, KH, KW]
device const char * src1, // Inputs [IC * N, ID, IH, IW]
device char * dst, // Outputs [OC * N, OD, OH, OW]
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]]) {
// 1. Un-flatten the spatial dimension from Grid X
int64_t spatial_idx = tgpig.x * 32 + tpitg.x;
if (spatial_idx >= args.OW * args.OH * args.OD) {
return; // Thread falls outside the spatial volume
}
int64_t od = spatial_idx / (args.OW * args.OH);
int64_t oh = (spatial_idx / args.OW) % args.OH;
int64_t ow = spatial_idx % args.OW;
// 2. Map Y to Channels, Z to Batch
int64_t oc = tgpig.y;
int64_t batch_idx = tgpig.z;
// 3. Calculate anchor coordinates in the Input volume
int64_t i_w_base = ow * args.s0 - args.p0;
int64_t i_h_base = oh * args.s1 - args.p1;
int64_t i_d_base = od * args.s2 - args.p2;
float sum = 0.0f;
// 4. Gather Loop (Iterate over Input Channels -> Depth -> Height -> Width)
for (int64_t ic = 0; ic < args.IC; ++ic) {
// ggml packs batch and channel together in the 4th dimension
int64_t src_cn_idx = batch_idx * args.IC + ic;
int64_t w_cn_idx = oc * args.IC + ic;
for (int64_t kz = 0; kz < args.KD; ++kz) {
int64_t id = i_d_base + kz * args.d2;
if (id < 0 || id >= args.ID) continue; // Boundary check (Padding)
for (int64_t ky = 0; ky < args.KH; ++ky) {
int64_t ih = i_h_base + ky * args.d1;
if (ih < 0 || ih >= args.IH) continue;
for (int64_t kx = 0; kx < args.KW; ++kx) {
int64_t iw = i_w_base + kx * args.d0;
if (iw < 0 || iw >= args.IW) continue;
// Convert multi-dimensional coordinates to flat byte offsets
int64_t w_idx = kx*args.nb00 + ky*args.nb01 + kz*args.nb02 + w_cn_idx*args.nb03;
int64_t i_idx = iw*args.nb10 + ih*args.nb11 + id*args.nb12 + src_cn_idx*args.nb13;
// Dereference memory and cast weights to f32 if they were f16
float w_val = (float)*(device const T*)((device const char*)src0 + w_idx);
float i_val = *(device const float*)((device const char*)src1 + i_idx);
sum += w_val * i_val;
}
}
}
}
// 5. Write the accumulated value out to RAM
int64_t dst_cn_idx = batch_idx * args.OC + oc;
int64_t d_idx = ow*args.nb0 + oh*args.nb1 + od*args.nb2 + dst_cn_idx*args.nb3;
*(device float*)(dst + d_idx) = sum;
}
// Explicit instantiations so the JIT compiler can find them by name
template [[host_name("kernel_conv_3d_f32_f32")]]
kernel void kernel_conv_3d<float>(
constant ggml_metal_kargs_conv_3d & args,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]]);
// Explicit instantiation for f16 weights
template [[host_name("kernel_conv_3d_f16_f32")]]
kernel void kernel_conv_3d<half>(
constant ggml_metal_kargs_conv_3d & args,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]]);
-686
View File
@@ -1,686 +0,0 @@
#pragma once
#include "common.h"
#define GGML_COMMON_DECL_METAL
#define GGML_COMMON_IMPL_METAL
#if defined(GGML_METAL_EMBED_LIBRARY)
__embed_ggml-common.h__
#else
#include "ggml-common.h"
#endif
#define QK_NL 16 // shared by mul_mm and get_rows_q instantiations
// NOTE: this is not dequantizing - we are simply fitting the template
template <typename type4x4>
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
reg = (type4x4)(*src);
}
template <typename type4>
void dequantize_f32_t4(device const float4 * src, short il, thread type4 & reg) {
reg = (type4)(*src);
}
template <typename type4x4>
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
reg = (type4x4)(*src);
}
template <typename type4>
void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) {
reg = (type4)(*(src));
}
#if defined(GGML_METAL_HAS_BF16)
template <typename type4x4>
void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
reg = (type4x4)(*src);
}
template <typename type4>
void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg) {
reg = (type4)(*(src));
}
#endif
template <typename type4x4>
void dequantize_q1_0(device const block_q1_0 * xb, short il, thread type4x4 & reg) {
device const uint8_t * qs = xb->qs;
const float d = xb->d;
const float neg_d = -d;
const int byte_offset = il * 2; // il*16 bits = il*2 bytes
const uint8_t b0 = qs[byte_offset];
const uint8_t b1 = qs[byte_offset + 1];
float4x4 reg_f;
reg_f[0][0] = select(neg_d, d, bool(b0 & 0x01));
reg_f[0][1] = select(neg_d, d, bool(b0 & 0x02));
reg_f[0][2] = select(neg_d, d, bool(b0 & 0x04));
reg_f[0][3] = select(neg_d, d, bool(b0 & 0x08));
reg_f[1][0] = select(neg_d, d, bool(b0 & 0x10));
reg_f[1][1] = select(neg_d, d, bool(b0 & 0x20));
reg_f[1][2] = select(neg_d, d, bool(b0 & 0x40));
reg_f[1][3] = select(neg_d, d, bool(b0 & 0x80));
reg_f[2][0] = select(neg_d, d, bool(b1 & 0x01));
reg_f[2][1] = select(neg_d, d, bool(b1 & 0x02));
reg_f[2][2] = select(neg_d, d, bool(b1 & 0x04));
reg_f[2][3] = select(neg_d, d, bool(b1 & 0x08));
reg_f[3][0] = select(neg_d, d, bool(b1 & 0x10));
reg_f[3][1] = select(neg_d, d, bool(b1 & 0x20));
reg_f[3][2] = select(neg_d, d, bool(b1 & 0x40));
reg_f[3][3] = select(neg_d, d, bool(b1 & 0x80));
reg = (type4x4) reg_f;
}
template <typename type4>
void dequantize_q1_0_t4(device const block_q1_0 * xb, short il, thread type4 & reg) {
const float d = xb->d;
const float neg_d = -d;
const int base = il * 4;
const uint8_t byte = xb->qs[base / 8];
const int s = base % 8;
float4 reg_f;
reg_f[0] = select(neg_d, d, bool((byte >> (s )) & 1));
reg_f[1] = select(neg_d, d, bool((byte >> (s + 1)) & 1));
reg_f[2] = select(neg_d, d, bool((byte >> (s + 2)) & 1));
reg_f[3] = select(neg_d, d, bool((byte >> (s + 3)) & 1));
reg = (type4) reg_f;
}
template <typename type4x4>
void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
const float d1 = il ? (xb->d / 16.h) : xb->d;
const float d2 = d1 / 256.f;
const float md = -8.h * xb->d;
const ushort mask0 = il ? 0x00F0 : 0x000F;
const ushort mask1 = mask0 << 8;
float4x4 reg_f;
for (int i = 0; i < 8; i++) {
reg_f[i/2][2*(i%2) + 0] = d1 * (qs[i] & mask0) + md;
reg_f[i/2][2*(i%2) + 1] = d2 * (qs[i] & mask1) + md;
}
reg = (type4x4) reg_f;
}
template <typename type4>
void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
const float d2 = d1 / 256.f;
const float md = -8.h * xb->d;
const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
const ushort mask1 = mask0 << 8;
for (int i = 0; i < 2; i++) {
reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + md;
reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + md;
}
}
template <typename type4x4>
void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
const float d1 = il ? (xb->d / 16.h) : xb->d;
const float d2 = d1 / 256.f;
const float m = xb->m;
const ushort mask0 = il ? 0x00F0 : 0x000F;
const ushort mask1 = mask0 << 8;
float4x4 reg_f;
for (int i = 0; i < 8; i++) {
reg_f[i/2][2*(i%2) + 0] = ((qs[i] & mask0) * d1) + m;
reg_f[i/2][2*(i%2) + 1] = ((qs[i] & mask1) * d2) + m;
}
reg = (type4x4) reg_f;
}
template <typename type4>
void dequantize_q4_1_t4(device const block_q4_1 * xb, short il, thread type4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
const float d2 = d1 / 256.f;
const float m = xb->m;
const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
const ushort mask1 = mask0 << 8;
for (int i = 0; i < 2; i++) {
reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + m;
reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + m;
}
}
template <typename type4x4>
void dequantize_q5_0(device const block_q5_0 * xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 3);
const float d = xb->d;
const float md = -16.h * xb->d;
const ushort mask = il ? 0x00F0 : 0x000F;
const uint32_t qh = *((device const uint32_t *)xb->qh);
const int x_mv = il ? 4 : 0;
const int gh_mv = il ? 12 : 0;
const int gh_bk = il ? 0 : 4;
float4x4 reg_f;
for (int i = 0; i < 8; i++) {
// extract the 5-th bits for x0 and x1
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
// combine the 4-bits from qs with the 5th bit
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
reg_f[i/2][2*(i%2) + 0] = d * x0 + md;
reg_f[i/2][2*(i%2) + 1] = d * x1 + md;
}
reg = (type4x4) reg_f;
}
template <typename type4>
void dequantize_q5_0_t4(device const block_q5_0 * xb, short il, thread type4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 3);
const float d = xb->d;
const float md = -16.h * xb->d;
const ushort mask = (il/4) ? 0x00F0 : 0x000F;
const uint32_t qh = *((device const uint32_t *)xb->qh);
const int x_mv = (il/4) ? 4 : 0;
const int gh_mv = (il/4) ? 12 : 0;
const int gh_bk = (il/4) ? 0 : 4;
for (int ii = 0; ii < 2; ii++) {
int i = 2*(il%4) + ii;
// extract the 5-th bits for x0 and x1
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
// combine the 4-bits from qs with the 5th bit
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
reg[2*ii + 0] = d * x0 + md;
reg[2*ii + 1] = d * x1 + md;
}
}
template <typename type4x4>
void dequantize_q5_1(device const block_q5_1 * xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 4);
const float d = xb->d;
const float m = xb->m;
const ushort mask = il ? 0x00F0 : 0x000F;
const uint32_t qh = *((device const uint32_t *)xb->qh);
const int x_mv = il ? 4 : 0;
const int gh_mv = il ? 12 : 0;
const int gh_bk = il ? 0 : 4;
float4x4 reg_f;
for (int i = 0; i < 8; i++) {
// extract the 5-th bits for x0 and x1
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
// combine the 4-bits from qs with the 5th bit
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
reg_f[i/2][2*(i%2) + 0] = d * x0 + m;
reg_f[i/2][2*(i%2) + 1] = d * x1 + m;
}
reg = (type4x4) reg_f;
}
template <typename type4>
void dequantize_q5_1_t4(device const block_q5_1 * xb, short il, thread type4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 4);
const float d = xb->d;
const float m = xb->m;
const ushort mask = (il/4) ? 0x00F0 : 0x000F;
const uint32_t qh = *((device const uint32_t *)xb->qh);
const int x_mv = (il/4) ? 4 : 0;
const int gh_mv = (il/4) ? 12 : 0;
const int gh_bk = (il/4) ? 0 : 4;
for (int ii = 0; ii < 2; ii++) {
int i = 2*(il%4) + ii;
// extract the 5-th bits for x0 and x1
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
// combine the 4-bits from qs with the 5th bit
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
reg[2*ii + 0] = d * x0 + m;
reg[2*ii + 1] = d * x1 + m;
}
}
template <typename type4x4>
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
device const int8_t * qs = ((device const int8_t *)xb->qs);
const float d = xb->d;
float4x4 reg_f;
for (int i = 0; i < 16; i++) {
reg_f[i/4][i%4] = (qs[i + 16*il] * d);
}
reg = (type4x4) reg_f;
}
template <typename type4>
void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & reg) {
device const int8_t * qs = ((device const int8_t *)xb->qs);
const float d = xb->d;
for (int i = 0; i < 4; i++) {
reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d);
}
}
template <typename type4x4>
void dequantize_mxfp4(device const block_mxfp4 * xb, short il, thread type4x4 & reg) {
device const uint8_t * q2 = (device const uint8_t *)xb->qs;
const float d = e8m0_to_fp32(xb->e);
const uint8_t shr = il >= 1 ? 4 : 0;
for (int i = 0; i < 4; ++i) {
reg[i][0] = d * kvalues_mxfp4_f[(q2[4*i + 0] >> shr) & 0x0F];
reg[i][1] = d * kvalues_mxfp4_f[(q2[4*i + 1] >> shr) & 0x0F];
reg[i][2] = d * kvalues_mxfp4_f[(q2[4*i + 2] >> shr) & 0x0F];
reg[i][3] = d * kvalues_mxfp4_f[(q2[4*i + 3] >> shr) & 0x0F];
}
}
template <typename type4>
void dequantize_mxfp4_t4(device const block_mxfp4 * xb, short il, thread type4 & reg) {
device const uint8_t * q2 = (device const uint8_t *)xb->qs;
const float d = e8m0_to_fp32(xb->e);
const short il4 = il%4;
const uint8_t shr = il >= 4 ? 4 : 0;
reg[0] = d * kvalues_mxfp4_f[(q2[4*il4 + 0] >> shr) & 0x0F];
reg[1] = d * kvalues_mxfp4_f[(q2[4*il4 + 1] >> shr) & 0x0F];
reg[2] = d * kvalues_mxfp4_f[(q2[4*il4 + 2] >> shr) & 0x0F];
reg[3] = d * kvalues_mxfp4_f[(q2[4*il4 + 3] >> shr) & 0x0F];
}
template <typename type4x4>
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
const float d = xb->d;
const float min = xb->dmin;
device const uint8_t * q = (device const uint8_t *)xb->qs;
float dl, ml;
uint8_t sc = xb->scales[il];
q = q + 32*(il/8) + 16*(il&1);
il = (il/2)%4;
half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
for (int i = 0; i < 16; ++i) {
reg[i/4][i%4] = dl * (q[i] & mask) - ml;
}
}
template <typename type4x4>
void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
const half d_all = xb->d;
device const uint8_t * q = (device const uint8_t *)xb->qs;
device const uint8_t * h = (device const uint8_t *)xb->hmask;
device const int8_t * scales = (device const int8_t *)xb->scales;
q = q + 32 * (il/8) + 16 * (il&1);
h = h + 16 * (il&1);
uint8_t m = 1 << (il/2);
uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
((il/4)>0 ? 12 : 3);
uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
: (scale_2&kmask2) | ((scale_1&kmask1) << 4);
float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
const float ml = 4.f * dl;
il = (il/2) & 3;
const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
dl *= coef;
for (int i = 0; i < 16; ++i) {
reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
}
}
static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
: uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
}
template <typename type4x4>
void dequantize_q4_K(device const block_q4_K * xb, short il, thread type4x4 & reg) {
device const uchar * q = xb->qs;
short is = (il/4) * 2;
q = q + (il/4) * 32 + 16 * (il&1);
il = il & 3;
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
const float d = il < 2 ? xb->d : xb->d / 16.h;
const float min = xb->dmin;
const float dl = d * sc[0];
const float ml = min * sc[1];
const ushort mask = il < 2 ? 0x0F : 0xF0;
for (int i = 0; i < 16; ++i) {
reg[i/4][i%4] = dl * (q[i] & mask) - ml;
}
}
template <typename type4x4>
void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
device const uint8_t * q = xb->qs;
device const uint8_t * qh = xb->qh;
short is = (il/4) * 2;
q = q + 32 * (il/4) + 16 * (il&1);
qh = qh + 16 * (il&1);
uint8_t ul = 1 << (il/2);
il = il & 3;
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
const float d = il < 2 ? xb->d : xb->d / 16.f;
const float min = xb->dmin;
const float dl = d * sc[0];
const float ml = min * sc[1];
const ushort mask = il<2 ? 0x0F : 0xF0;
const float qh_val = il<2 ? 16.f : 256.f;
for (int i = 0; i < 16; ++i) {
reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
}
}
template <typename type4x4>
void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
const half d_all = xb->d;
device const uint16_t * ql = (device const uint16_t *)xb->ql;
device const uint16_t * qh = (device const uint16_t *)xb->qh;
device const int8_t * scales = (device const int8_t *)xb->scales;
ql = ql + 32*(il/8) + 16*((il/2)&1) + 8*(il&1);
qh = qh + 16*(il/8) + 8*(il&1);
float sc = scales[(il%2) + 2 * ((il/2))];
il = (il/2) & 3;
const uint32_t kmask1 = il>1 ? (il>2 ? 0xC0C0C0C0 : 0x30303030) : (il>0 ? 0x0C0C0C0C : 0x03030303);
const uint32_t kmask2 = il>1 ? 0xF0F0F0F0 : 0x0F0F0F0F;
const float ml = d_all * sc * 32.f;
const float dl0 = d_all * sc;
const float dl1 = dl0 / 256.f;
const float dl2 = dl0 / (256.f * 256.f);
const float dl3 = dl0 / (256.f * 256.f * 256.f);
const uint8_t shr_h = il>2 ? 2 : 0;
const uint8_t shl_h = il>1 ? 0 : (il>0 ? 2 : 4);
const uint8_t shr_l = il>1 ? 4 : 0;
for (int i = 0; i < 4; ++i) {
const uint32_t low = (ql[2*i] | (uint32_t)(ql[2*i+1] << 16)) & kmask2;
const uint32_t high = (qh[2*i] | (uint32_t)(qh[2*i+1] << 16)) & kmask1;
const uint32_t q = ((high << shl_h) >> shr_h) | (low >> shr_l);
reg[i][0] = dl0 * ((half)(q & 0xFF)) - ml;
reg[i][1] = dl1 * ((float)(q & 0xFF00)) - ml;
reg[i][2] = dl2 * ((float)(q & 0xFF0000)) - ml;
reg[i][3] = dl3 * ((float)(q & 0xFF000000)) - ml;
}
}
template <typename type4x4>
void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const float d = xb->d;
const int ib32 = il/2;
il = il%2;
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
// each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
device const uint16_t * q2 = xb->qs + 4*ib32;
const uint32_t aux32_g = q2[0] | (q2[1] << 16);
const uint32_t aux32_s = q2[2] | (q2[3] << 16);
thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
for (int i = 0; i < 8; ++i) {
reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
}
grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
for (int i = 0; i < 8; ++i) {
reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
}
}
template <typename type4x4>
void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const float d = xb->d;
const int ib32 = il/2;
il = il%2;
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
device const uint16_t * q2 = xb->qs + 4*ib32;
const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));
uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];
for (int i = 0; i < 8; ++i) {
reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
}
grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));
signs = ksigns_iq2xs[q2[2*il+1] >> 9];
for (int i = 0; i < 8; ++i) {
reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
}
}
template <typename type4x4>
void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const float d = xb->d;
const int ib32 = il/2;
il = il%2;
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
device const uint8_t * q3 = xb->qs + 8*ib32;
device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32;
const uint32_t aux32 = gas[0] | (gas[1] << 16);
const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f;
constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]);
constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]);
uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127];
for (int i = 0; i < 4; ++i) {
reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
}
grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]);
grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]);
signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127];
for (int i = 0; i < 4; ++i) {
reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
}
}
template <typename type4x4>
void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const float d = xb->d;
const int ib32 = il/2;
il = il%2;
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
device const uint8_t * qs = xb->qs + 8*ib32;
device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
const uint8_t qh = xb->qh[ib32] >> 4*il;
const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));
constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256)));
constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256)));
for (int i = 0; i < 4; ++i) {
reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
}
grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256)));
grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256)));
for (int i = 0; i < 4; ++i) {
reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
}
}
template <typename type4x4>
void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const float d = xb->d;
const int ib32 = il/2;
il = il%2;
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
device const uint8_t * signs = qs + QK_K/8;
const uint8_t qh = xb->qh[ib32] >> 4*il;
const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300)));
constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300)));
for (int i = 0; i < 8; ++i) {
reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]);
reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]);
}
}
template <typename type4x4>
void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const int ib32 = il/2;
il = il%2;
const float d = xb->d;
device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
device const uint16_t * qh = xb->qh;
const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1);
const float ml = dl * (qh[ib32] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA);
const uint16_t h = qh[ib32] >> 6*il;
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700)));
constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700)));
for (int i = 0; i < 4; ++i) {
reg[0][i] = dl * (grid1[i] & 0xf) + ml;
reg[1][i] = dl * (grid1[i] >> 4) + ml;
reg[2][i] = dl * (grid2[i] & 0xf) + ml;
reg[3][i] = dl * (grid2[i] >> 4) + ml;
}
}
template <typename type4x4>
void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const int ib32 = il/2;
il = il%2;
device const uint16_t * sc = (device const uint16_t *)xb->scales;
iq1m_scale_t scale;
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
const float d = scale.f16;
device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
device const uint8_t * qh = xb->qh + 2*ib32 + il;
const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
for (int i = 0; i < 4; ++i) {
reg[0][i] = dl * (grid1[i] & 0xf) + ml1;
reg[1][i] = dl * (grid1[i] >> 4) + ml1;
reg[2][i] = dl * (grid2[i] & 0xf) + ml2;
reg[3][i] = dl * (grid2[i] >> 4) + ml2;
}
}
template <typename type4x4>
void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
device const uint16_t * q4 = (device const uint16_t *)xb->qs;
const float d = xb->d;
uint32_t aux32;
thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
for (int i = 0; i < 4; ++i) {
aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f;
reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
}
}
template <typename type4>
void dequantize_iq4_nl_t4(device const block_iq4_nl * xb, short il, thread type4 & reg) {
device const uint16_t * q4 = (device const uint16_t *)xb->qs;
const float d = xb->d;
uint32_t aux32;
thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
aux32 = ((q4[2*(il%4)] | (q4[2*(il%4)+1] << 16)) >> 4*(il/4)) & 0x0f0f0f0f;
reg[0] = d * kvalues_iq4nl_f[q8[0]];
reg[1] = d * kvalues_iq4nl_f[q8[1]];
reg[2] = d * kvalues_iq4nl_f[q8[2]];
reg[3] = d * kvalues_iq4nl_f[q8[3]];
}
template <typename type4x4>
void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const int ib32 = il/2;
il = il%2;
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32;
const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4);
const float d = (float)xb->d * (ls - 32);
uint32_t aux32;
thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
for (int i = 0; i < 4; ++i) {
aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f;
reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
}
}
File diff suppressed because it is too large Load Diff
@@ -1,250 +0,0 @@
#include "common.h"
constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]];
constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]];
constant short FC_gated_delta_net_K [[function_constant(FC_GATED_DELTA_NET + 2)]];
#if 1
template<short NSG>
kernel void kernel_gated_delta_net_impl(
constant ggml_metal_kargs_gated_delta_net & args,
device const char * q,
device const char * k,
device const char * v,
device const char * g,
device const char * b,
device const char * s,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
#define S_v FC_gated_delta_net_ne20
#define G FC_gated_delta_net_ne30
#define K FC_gated_delta_net_K
const uint tx = tpitg.x;
const uint ty = tpitg.y;
const uint i23 = tgpig.z; // B (n_seqs)
const uint i21 = tgpig.y; // H (head)
const uint i20 = tgpig.x*NSG + ty; // row within S_v
const uint i01 = i21 % args.ne01;
const uint i11 = i21 % args.ne11;
const float scale = 1.0f / sqrt((float)S_v);
// input state layout [S_v, S_v, H, n_seqs] (s0 only): per-seq stride is H*D.
// state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous
const uint state_in_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
device const float * s_ptr = (device const float *) (s) + state_in_base;
float ls[NSG];
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
ls[j] = s_ptr[is];
}
device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01);
device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11);
device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21);
device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
// snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back.
// When n_tokens < K, only slots 0..n_tokens-1 are written; older slots are caller-owned.
// output state base offset: after attention scores
const uint attn_size = args.ne22 * args.ne21 * S_v * args.ne23;
// output state per-slot size: S_v * S_v * H * n_seqs
const uint state_size_per_snap = S_v * S_v * args.ne21 * args.ne23;
// per-(seq,head) offset within a slot
const uint state_out_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
for (short t = 0; t < args.ne22; t++) {
float s_k = 0.0f;
if (G == 1) {
const float g_exp = exp(g_ptr[0]);
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
ls[j] *= g_exp;
s_k += ls[j]*k_ptr[is];
}
} else {
// KDA
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
ls[j] *= exp(g_ptr[is]);
s_k += ls[j]*k_ptr[is];
}
}
s_k = simd_sum(s_k);
const float d = (v_ptr[i20] - s_k)*b_ptr[0];
float y = 0.0f;
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
ls[j] += k_ptr[is]*d;
y += ls[j]*q_ptr[is];
}
y = simd_sum(y);
if (tx == 0) {
dst_attn[t*args.ne21*S_v] = y*scale;
}
q_ptr += args.ns02;
k_ptr += args.ns12;
v_ptr += args.ns22;
b_ptr += args.ne21;
g_ptr += args.ne21*G;
if (K > 1) {
const int target_slot = (int)args.ne22 - 1 - (int)t;
if (target_slot >= 0 && target_slot < (int)K) {
device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base;
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
dst_state[is] = ls[j];
}
}
}
}
if (K == 1) {
device float * dst_state = (device float *) (dst) + attn_size + state_out_base;
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
dst_state[is] = ls[j];
}
}
#undef S_v
#undef G
#undef K
}
typedef decltype(kernel_gated_delta_net_impl<4>) kernel_gated_delta_net_t;
template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<1>;
template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<2>;
template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<4>;
#else
// a simplified version of the above
// no performance improvement, so keep the above version for now
template<typename T, short NSG>
kernel void kernel_gated_delta_net_impl(
constant ggml_metal_kargs_gated_delta_net & args,
device const char * q,
device const char * k,
device const char * v,
device const char * g,
device const char * b,
device const char * s,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
#define S_v FC_gated_delta_net_ne20
#define G FC_gated_delta_net_ne30
const uint tx = tpitg.x;
const uint ty = tpitg.y;
const uint i23 = tgpig.z; // B
const uint i21 = tgpig.y; // H
const uint i20 = tgpig.x*NSG + ty;
const uint i01 = i21 % args.ne01;
const uint i11 = i21 % args.ne11;
const float scale = 1.0f / sqrt((float)S_v);
device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20;
float lsf[NSG];
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
lsf[j] = s_ptr[is*S_v];
}
thread T * ls = (thread T *) (lsf);
device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01);
device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11);
device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21);
device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
for (short t = 0; t < args.ne22; t++) {
device const T * qt_ptr = (device const T *) (q_ptr);
device const T * kt_ptr = (device const T *) (k_ptr);
device const T * gt_ptr = (device const T *) (g_ptr);
if (G == 1) {
*ls *= exp(g_ptr[0]);
} else {
// KDA
*ls *= exp(gt_ptr[tx]);
}
const float s_k = simd_sum(dot(*ls, kt_ptr[tx]));
const float d = (v_ptr[i20] - s_k)*b_ptr[0];
*ls += kt_ptr[tx]*d;
const float y = simd_sum(dot(*ls, qt_ptr[tx]));
if (tx == 0) {
*dst_attn = y*scale;
}
q_ptr += args.ns02;
k_ptr += args.ns12;
v_ptr += args.ns22;
b_ptr += args.ne21;
g_ptr += args.ne21*G;
dst_attn += args.ne21*S_v;
}
device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20;
device T * dstt_state = (device T *) (dst_state);
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
dst_state[is*S_v] = lsf[j];
}
#undef S_v
#undef G
}
typedef decltype(kernel_gated_delta_net_impl<float4, 4>) kernel_gated_delta_net_t;
template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float, 1>;
template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float2, 2>;
template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float4, 4>;
#endif
-347
View File
@@ -1,347 +0,0 @@
#include "common.h"
kernel void kernel_argmax_f32(
constant ggml_metal_kargs_argmax & args,
device const char * src0,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint ntg[[threads_per_threadgroup]]) {
device const float * x_row = (device const float *) ((device const char *) src0 + tgpig * args.nb01);
float lmax = -INFINITY;
int32_t larg = -1;
for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
if (x_row[i00] > lmax) {
lmax = x_row[i00];
larg = i00;
}
}
// find the argmax value in the block
float max_val = simd_max(lmax);
int32_t arg_val = simd_max(select(-1, larg, lmax == max_val));
device int32_t * dst_i32 = (device int32_t *) dst;
threadgroup float * shared_maxval = (threadgroup float *) shmem;
threadgroup int32_t * shared_argmax = (threadgroup int32_t *) shmem + N_SIMDWIDTH;
if (ntg > N_SIMDWIDTH) {
if (sgitg == 0) {
shared_maxval[tiisg] = -INFINITY;
shared_argmax[tiisg] = -1;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shared_maxval[sgitg] = max_val;
shared_argmax[sgitg] = arg_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max_val = shared_maxval[tiisg];
arg_val = shared_argmax[tiisg];
float max_val_reduced = simd_max(max_val);
int32_t arg_val_reduced = simd_max(select(-1, arg_val, max_val == max_val_reduced));
dst_i32[tgpig] = arg_val_reduced;
return;
}
dst_i32[tgpig] = arg_val;
}
kernel void kernel_diag_f32(
constant ggml_metal_kargs_diag & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]]) {
constexpr short NW = N_SIMDWIDTH;
const int32_t i3 = tgpig.z;
const int32_t i2 = tgpig.y;
const int32_t i1 = tgpig.x;
device const float * src0_ptr = (device const float *)(src0 + i2*args.nb02 + i3*args.nb03);
device float * dst_ptr = (device float *)(dst + i1*args.nb01 + i2*args.nb2 + i3*args.nb3);
for (int i0 = tiitg; i0 < args.ne0; i0 += NW) {
dst_ptr[i0] = i0 == i1 ? src0_ptr[i0] : 0.0f;
}
}
kernel void kernel_roll_f32(
constant ggml_metal_kargs_roll & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i3 = tgpig.z;
const int64_t i2 = tgpig.y;
const int64_t i1 = tgpig.x;
device const float * src0_ptr = (device const float *) src0;
device float * dst_ptr = (device float *) dst;
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
// apply shifts and wrap around
int64_t i00 = i0 - args.s0;
int64_t i01 = i1 - args.s1;
int64_t i02 = i2 - args.s2;
int64_t i03 = i3 - args.s3;
if (i00 < 0) { i00 += args.ne00; } else if (i00 >= args.ne00) { i00 -= args.ne00; }
if (i01 < 0) { i01 += args.ne01; } else if (i01 >= args.ne01) { i01 -= args.ne01; }
if (i02 < 0) { i02 += args.ne02; } else if (i02 >= args.ne02) { i02 -= args.ne02; }
if (i03 < 0) { i03 += args.ne03; } else if (i03 >= args.ne03) { i03 -= args.ne03; }
int64_t src_idx = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00 + i00;
int64_t dst_idx = i3 *args.ne2 *args.ne1 *args.ne0 + i2 *args.ne1 *args.ne0 + i1 *args.ne0 + i0;
dst_ptr[dst_idx] = src0_ptr[src_idx];
}
}
template <typename T>
kernel void kernel_pad_impl(
constant ggml_metal_kargs_pad & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int32_t i3 = tgpig.z;
const int32_t i2 = tgpig.y;
const int32_t k0 = tgpig.x/args.ne1;
const int32_t i1 = tgpig.x - k0*args.ne1;
const int32_t i03 = i3;
const int32_t i02 = i2;
const int32_t i01 = i1;
device const T * src0_ptr = (device const T *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
device T * dst_ptr = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
for (int32_t l0 = 0; l0 < 1024; l0 += ntg.x) {
const int32_t i0 = k0*1024 + tpitg.x + l0;
if (i0 >= args.ne0) {
break;
}
if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
dst_ptr[i0] = src0_ptr[i0];
} else {
dst_ptr[i0] = 0.0f;
}
}
}
typedef decltype(kernel_pad_impl<float>) kernel_pad_t;
template [[host_name("kernel_pad_f32")]] kernel kernel_pad_t kernel_pad_impl<float>;
template [[host_name("kernel_pad_f32_4")]] kernel kernel_pad_t kernel_pad_impl<float4>;
// TODO: this is slow - optimize
kernel void kernel_pad_reflect_1d_f32(
constant ggml_metal_kargs_pad_reflect_1d & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i3 = tgpig.z;
const int64_t i2 = tgpig.y;
const int64_t i1 = tgpig.x;
const int64_t i03 = i3;
const int64_t i02 = i2;
const int64_t i01 = i1;
device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
if (i0 < args.p0) {
dst_ptr[i0] = src0_ptr[args.p0 - i0];
} else if (i0 < args.ne0 - args.p1) {
dst_ptr[i0] = src0_ptr[i0 - args.p0];
} else {
dst_ptr[i0] = src0_ptr[(args.ne0 - args.p1 - args.p0) - (args.p1 + 1 - (args.ne0 - i0)) - 1];
}
}
}
}
kernel void kernel_arange_f32(
constant ggml_metal_kargs_arange & args,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
device float * dst_ptr = (device float *) dst;
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
dst_ptr[i0] = args.start + args.step * i0;
}
}
kernel void kernel_timestep_embedding_f32(
constant ggml_metal_kargs_timestep_embedding & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
int i = tgpig.x;
device float * embed_data = (device float *)(dst + i*args.nb1);
int half_ = args.dim / 2;
for (int j = tpitg.x; j < half_; j += ntg.x) {
float timestep = ((device float *)src0)[i];
float freq = (float)exp(-log((float)args.max_period) * j / half_);
float arg = timestep * freq;
embed_data[j ] = cos(arg);
embed_data[j + half_] = sin(arg);
}
if (args.dim % 2 != 0 && tpitg.x == 0) {
embed_data[2 * half_] = 0.f;
}
}
kernel void kernel_opt_step_adamw_f32(
constant ggml_metal_kargs_opt_step_adamw & args,
device float * x,
device const float * g,
device float * g_m,
device float * g_v,
device const float * pars,
uint gid[[thread_position_in_grid]]) {
if (gid >= args.np) {
return;
}
const float alpha = pars[0];
const float beta1 = pars[1];
const float beta2 = pars[2];
const float eps = pars[3];
const float wd = pars[4];
const float beta1h = pars[5];
const float beta2h = pars[6];
const float gi = g[gid];
const float gmi = g_m[gid] * beta1 + gi * (1.0f - beta1);
const float gvi = g_v[gid] * beta2 + gi * gi * (1.0f - beta2);
g_m[gid] = gmi;
g_v[gid] = gvi;
const float mh = gmi * beta1h;
const float vh = sqrt(gvi * beta2h) + eps;
x[gid] = x[gid] * (1.0f - alpha * wd) - alpha * mh / vh;
}
kernel void kernel_opt_step_sgd_f32(
constant ggml_metal_kargs_opt_step_sgd & args,
device float * x,
device const float * g,
device const float * pars,
uint gid[[thread_position_in_grid]]) {
if (gid >= args.np) {
return;
}
x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid];
}
template<typename T>
kernel void kernel_memset(
constant ggml_metal_kargs_memset & args,
device T * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = args.val;
}
typedef decltype(kernel_memset<int64_t>) kernel_memset_t;
template [[host_name("kernel_memset_i64")]] kernel kernel_memset_t kernel_memset<int64_t>;
constant short FC_count_equal_nsg [[function_constant(FC_COUNT_EQUAL + 0)]];
template<typename T>
kernel void kernel_count_equal(
constant ggml_metal_kargs_count_equal & args,
device const char * src0,
device const char * src1,
device atomic_int * dst,
threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const short NSG = FC_count_equal_nsg;
const int i3 = tgpig.z;
const int i2 = tgpig.y;
const int i1 = tgpig.x;
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
return;
}
int sum = 0;
device const char * base0 = src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03;
device const char * base1 = src1 + i1*args.nb11 + i2*args.nb12 + i3*args.nb13;
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
const T v0 = *(device const T *)(base0 + i0*args.nb00);
const T v1 = *(device const T *)(base1 + i0*args.nb10);
sum += (v0 == v1);
}
sum = simd_sum(sum);
if (tiisg == 0) {
shmem_i32[sgitg] = sum;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (sgitg == 0) {
float v = 0.0f;
if (tpitg.x < NSG) {
v = shmem_i32[tpitg.x];
}
float total = simd_sum(v);
if (tpitg.x == 0) {
atomic_fetch_add_explicit(dst, (int32_t) total, memory_order_relaxed);
}
}
}
typedef decltype(kernel_count_equal<int32_t>) kernel_count_equal_t;
template [[host_name("kernel_count_equal_i32")]] kernel kernel_count_equal_t kernel_count_equal<int32_t>;
-838
View File
@@ -1,838 +0,0 @@
#include "common.h"
#include "dequantize.h"
constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]];
constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]];
constant short FC_mul_mm_ne12 [[function_constant(FC_MUL_MM + 2)]];
constant short FC_mul_mm_ne13 [[function_constant(FC_MUL_MM + 3)]];
constant short FC_mul_mm_r2 [[function_constant(FC_MUL_MM + 4)]];
constant short FC_mul_mm_r3 [[function_constant(FC_MUL_MM + 5)]];
// each block_q contains 16*nl weights
#ifdef GGML_METAL_HAS_TENSOR
template<
typename SA, typename SA_4x4, typename SA_8x8,
typename SB, typename SB_2x4, typename SB_8x8,
typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread SA_4x4 &),
typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm(
constant ggml_metal_kargs_mul_mm & args,
device const char * srcA,
device const char * srcB,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig [[threadgroup_position_in_grid]],
ushort tiitg [[thread_index_in_threadgroup]],
ushort sgitg [[simdgroup_index_in_threadgroup]]) {
(void) sgitg;
// Matrix dimensions: A(M,K) x B(K,N) -> C(M,N)
const int K = args.ne00;
const int M = args.ne0;
const int N = args.ne1;
// Batch dimension handling
const int im = tgpig.z;
const int i12 = im % FC_mul_mm_ne12;
const int i13 = im / FC_mul_mm_ne12;
// Batch offsets for srcA and srcB
const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03;
// Tile dimensions
constexpr int NRB = SZ_SIMDGROUP * N_MM_BLOCK_X * N_MM_SIMD_GROUP_X;
constexpr int NRA = SZ_SIMDGROUP * N_MM_BLOCK_Y * N_MM_SIMD_GROUP_Y;
// Tile offsets in output matrix
const int ra = tgpig.y * NRA;
const int rb = tgpig.x * NRB;
// Threadgroup memory for dequantized A tile only
threadgroup SA * sa = (threadgroup SA *)(shmem);
// Work-item count for A loading
constexpr int A_WORK_ITEMS = NRA * N_MM_NK;
constexpr int NUM_THREADS = N_SIMDWIDTH * N_MM_SIMD_GROUP_X * N_MM_SIMD_GROUP_Y;
// tA wraps threadgroup memory
auto tA = tensor(sa, dextents<int32_t, 2>(N_MM_NK_TOTAL, NRA));
// tB wraps device memory directly
device T1 * ptrB = (device T1 *)(srcB + args.nb12*i12 + args.nb13*i13);
const int strideB = args.nb11 / sizeof(T1);
auto tB = tensor(ptrB, dextents<int32_t, 2>(K, N), array<int, 2>({1, strideB}));
// Configure matmul operation
mpp::tensor_ops::matmul2d<
mpp::tensor_ops::matmul2d_descriptor(
NRB, NRA, N_MM_NK_TOTAL, false, true, true,
mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
execution_simdgroups<N_MM_SIMD_GROUP_X * N_MM_SIMD_GROUP_Y>> mm;
auto cT = mm.get_destination_cooperative_tensor<decltype(tB), decltype(tA), float>();
// Accumulate partial results over K dimension
for (int loop_k = 0; loop_k < K; loop_k += N_MM_NK_TOTAL) {
// === PHASE 1: Dequantization of A into threadgroup memory ===
for (int work = tiitg; work < A_WORK_ITEMS; work += NUM_THREADS) {
const int row = work / N_MM_NK;
const int k_chunk = work % N_MM_NK;
const int k_pos = loop_k + k_chunk * 16;
const short k_base = k_chunk * 16;
// Bounds check: skip device read if row is out of matrix bounds
if (ra + row < M) {
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
// Element-wise reads when K is not aligned (nb01 not aligned for half4x4/float4x4).
// MSL spec Table 2.5: half4x4 requires 8-byte alignment. When K is odd,
// nb01 = K*2 is not 8-byte aligned, so odd-row pointers are misaligned.
// Mirrors the legacy kernel's existing guard.
device const T0 * row_ptr = (device const T0 *)(srcA + args.nb01 * (ra + row) + offset0);
FOR_UNROLL (short i = 0; i < 16; i++) {
sa[row * N_MM_NK_TOTAL + (k_base + i)] = (k_pos + i < K) ? (SA) row_ptr[k_pos + i] : (SA)0;
}
} else {
const int block_idx = k_pos / (16 * nl);
const short il = (k_pos / 16) % nl;
device const block_q * row_ptr = (device const block_q *)(srcA + args.nb01 * (ra + row) + offset0);
SA_4x4 temp_a;
dequantize_func(row_ptr + block_idx, il, temp_a);
FOR_UNROLL (short i = 0; i < 16; i++) {
// Zero-pad A for K positions beyond valid range (handles partial K iterations)
sa[row * N_MM_NK_TOTAL + (k_base + i)] = (k_pos + i < K) ? temp_a[i/4][i%4] : (SA)0;
}
}
} else {
// Zero-pad rows beyond matrix bounds
FOR_UNROLL (short i = 0; i < 16; i++) {
sa[row * N_MM_NK_TOTAL + (k_base + i)] = (SA)0;
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// === PHASE 2: Tensor matmul ===
auto mA = tA.slice(0, 0);
auto mB = tB.slice(loop_k, rb);
mm.run(mB, mA, cT);
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// Store result tile to output matrix (with batch offset)
// cT.store handles bounds checking via tD's extents (M, N)
device float * dstBatch = (device float *)dst + im * N * M;
auto tD = tensor(dstBatch, dextents<int32_t, 2>(M, N), array<int, 2>({1, M}));
cT.store(tD.slice(ra, rb));
}
#else
template<
typename S0, typename S0_4x4, typename S0_8x8,
typename S1, typename S1_2x4, typename S1_8x8,
typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &),
typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm(
constant ggml_metal_kargs_mul_mm & args,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
constexpr int NR0 = 64;
constexpr int NR1 = 32;
constexpr int NK = 32;
constexpr int NL0 = NK/16;
constexpr int NL1 = NK/8;
const int im = tgpig.z;
const int r0 = tgpig.y*NR0;
const int r1 = tgpig.x*NR1;
// if this block is of 64x32 shape or smaller
const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;
const short nr1 = (args.ne1 - r1 < NR1) ? (args.ne1 - r1) : NR1;
// a thread shouldn't load data outside of the matrix
const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63
const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31
const short il0 = (tiitg % NL0);
short il = il0;
const int i12 = im % FC_mul_mm_ne12;
const int i13 = im / FC_mul_mm_ne12;
const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03;
const short offset1 = il0/nl;
device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;
const short iy = 8*(tiitg % NL1);
device const T1 * y = (device const T1 *)(src1
+ args.nb13*i13
+ args.nb12*i12
+ args.nb11*(r1 + lr1)
+ args.nb10*iy);
S0_8x8 ma[4];
S1_8x8 mb[2];
simdgroup_float8x8 mc[8];
for (short i = 0; i < 8; i++){
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
}
for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
// load data and store to threadgroup memory
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// no need for dequantization
for (short i = 0; i < 16; i++) {
const short sx = 2*il0 + i/8;
const short sy = (tiitg/NL0)/8;
//const short lx = i%8;
//const short ly = (tiitg/NL0)%8;
const short lx = (tiitg/NL0)%8;
const short ly = i%8;
const short ib = 8*sx + sy;
*(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
}
} else {
S0_4x4 temp_a;
dequantize_func(x, il, temp_a);
threadgroup_barrier(mem_flags::mem_threadgroup);
FOR_UNROLL (short i = 0; i < 16; i++) {
const short sx = 2*il0 + i/8;
const short sy = (tiitg/NL0)/8;
//const short lx = i%8;
//const short ly = (tiitg/NL0)%8;
const short lx = (tiitg/NL0)%8;
const short ly = i%8;
const short ib = 8*sx + sy;
// NOTE: this is massively slower.. WTF?
//sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4];
*(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4];
}
}
if (FC_mul_mm_bc_inp) {
for (short i = 0; i < 8; ++i) {
const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8;
const short lx = i;
const short ly = (tiitg/NL1)%8;
//const short lx = (tiitg/NL1)%8;
//const short ly = i;
const short ib = 4*sx + sy;
*(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
}
} else {
const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8;
//const short dx = sx;
//const short dy = sy;
const short ly = (tiitg/NL1)%8;
const short ib = 4*sx + sy;
*(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
}
il = (il + 2 < nl) ? il + 2 : il % 2;
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
y += NK;
threadgroup_barrier(mem_flags::mem_threadgroup);
// load matrices from threadgroup memory and conduct outer products
threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
simdgroup_barrier(mem_flags::mem_none);
FOR_UNROLL (short i = 0; i < 4; i++) {
simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
}
simdgroup_barrier(mem_flags::mem_none);
FOR_UNROLL (short i = 0; i < 2; i++) {
simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
}
simdgroup_barrier(mem_flags::mem_none);
FOR_UNROLL (short i = 0; i < 8; i++){
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
}
lsma += 8*64;
lsmb += 4*64;
}
}
if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) {
// if no bounds checks on the output are needed, we can directly write to device memory
device float * C = (device float *) dst +
(r0 + 32*(sgitg & 1)) + \
(r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
for (short i = 0; i < 8; i++) {
simdgroup_store(mc[i], C + 8*(i%4) + 8*args.ne0*(i/4), args.ne0, 0, false);
}
} else {
// block is smaller than 64x32, we should avoid writing data outside of the matrix
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
for (short i = 0; i < 8; i++) {
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (sgitg == 0) {
for (int j = tiitg; j < nr1; j += NR1) {
device float * D = (device float *) dst + r0 + (r1 + j)*args.ne0 + im*args.ne1*args.ne0;
device float4 * D4 = (device float4 *) D;
threadgroup float * C = temp_str + (j*NR0);
threadgroup float4 * C4 = (threadgroup float4 *) C;
int i = 0;
for (; i < nr0/4; i++) {
*(D4 + i) = *(C4 + i);
}
i *= 4;
for (; i < nr0; i++) {
*(D + i) = *(C + i);
}
}
}
}
}
#endif // GGML_METAL_HAS_TENSOR
template<short ne20> // n_expert_used
kernel void kernel_mul_mm_id_map0(
constant ggml_metal_kargs_mul_mm_id_map0 & args,
device const char * src2,
device char * htpe,
device char * hids,
threadgroup char * shmem [[threadgroup(0)]],
ushort tpitg[[thread_position_in_threadgroup]],
ushort ntg[[threads_per_threadgroup]]) {
const short ide = tpitg; // expert id
uint32_t n_all = 0;
device int32_t * ids_i32 = (device int32_t *) hids + ide*args.ne21;
for (int i21 = 0; i21 < args.ne21; i21 += ntg) { // n_tokens
if (i21 + tpitg < args.ne21) {
device const int32_t * src2_i32 = (device const int32_t *) (src2 + (i21 + tpitg)*args.nb21);
threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + tpitg*ne20;
#pragma unroll(ne20)
for (short i20 = 0; i20 < ne20; i20++) {
sids[i20] = src2_i32[i20];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
for (short t = 0; t < ntg; t++) {
if (i21 + t >= args.ne21) {
break;
}
threadgroup const uint16_t * sids = (threadgroup const uint16_t *) shmem + t*ne20;
short sel = 0;
#pragma unroll(ne20)
for (short i20 = 0; i20 < ne20; i20++) {
sel += (sids[i20] == ide)*(i20 + 1);
}
ids_i32[n_all] = (i21 + t)*ne20 + sel - 1;
n_all += sel > 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
device uint32_t * tpe_u32 = (device uint32_t *) (htpe);
tpe_u32[ide] = n_all;
}
typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t;
template [[host_name("kernel_mul_mm_id_map0_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>;
template [[host_name("kernel_mul_mm_id_map0_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>;
template [[host_name("kernel_mul_mm_id_map0_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
template [[host_name("kernel_mul_mm_id_map0_ne20_5" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<5>;
template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;
template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm_id(
constant ggml_metal_kargs_mul_mm_id & args,
device const char * src0,
device const char * src1,
device const char * htpe,
device const char * hids,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
#ifdef GGML_METAL_HAS_TENSOR
threadgroup float * sc = (threadgroup float *)(shmem);
#endif
constexpr int NR0 = 64;
constexpr int NR1 = 32;
constexpr int NK = 32;
constexpr int NL0 = NK/16;
constexpr int NL1 = NK/8;
const int im = tgpig.z; // expert
const int r0 = tgpig.y*NR0;
const int r1 = tgpig.x*NR1;
device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe);
device const int32_t * ids_i32 = (device const int32_t *) (hids);
const int32_t neh1 = tpe_u32[im];
if (r1 >= neh1) {
return;
}
// if this block is of 64x32 shape or smaller
const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;
const short nr1 = ( neh1 - r1 < NR1) ? ( neh1 - r1) : NR1;
// a thread shouldn't load data outside of the matrix
const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63
const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31
const short il0 = (tiitg % NL0);
short il = il0;
const int id = ids_i32[im*args.ne21 + r1 + lr1];
const short i11 = (id % args.ne20) % args.ne11;
const short i12 = (id / args.ne20);
const short i13 = 0;
const uint64_t offset0 = im*args.nb02 + i13*args.nb03;
const short offset1 = il0/nl;
device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;
const short iy = 8*(tiitg % NL1);
device const T1 * y = (device const T1 *)(src1
+ args.nb13*i13
+ args.nb12*i12
+ args.nb11*i11
+ args.nb10*iy);
#ifndef GGML_METAL_HAS_TENSOR
S0_8x8 ma[4];
S1_8x8 mb[2];
simdgroup_float8x8 mc[8];
for (short i = 0; i < 8; i++){
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
}
#else
auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK, NR0));
auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK ));
mpp::tensor_ops::matmul2d<
mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
execution_simdgroups<4>> mm;
auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();
#endif
for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
#ifndef GGML_METAL_HAS_TENSOR
// load data and store to threadgroup memory
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// no need for dequantization
for (short i = 0; i < 16; i++) {
const short sx = 2*il0 + i/8;
const short sy = (tiitg/NL0)/8;
//const short lx = i%8;
//const short ly = (tiitg/NL0)%8;
const short lx = (tiitg/NL0)%8;
const short ly = i%8;
const short ib = 8*sx + sy;
*(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? (S0) *((device T0 *) x + i) : (S0) 0;
}
} else {
S0_4x4 temp_a;
dequantize_func(x, il, temp_a);
threadgroup_barrier(mem_flags::mem_threadgroup);
FOR_UNROLL (short i = 0; i < 16; i++) {
const short sx = 2*il0 + i/8;
const short sy = (tiitg/NL0)/8;
//const short lx = i%8;
//const short ly = (tiitg/NL0)%8;
const short lx = (tiitg/NL0)%8;
const short ly = i%8;
const short ib = 8*sx + sy;
// NOTE: this is massively slower.. WTF?
//sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4];
*(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4];
}
}
if (FC_mul_mm_bc_inp) {
for (short i = 0; i < 8; ++i) {
const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8;
const short lx = i;
const short ly = (tiitg/NL1)%8;
//const short lx = (tiitg/NL1)%8;
//const short ly = i;
const short ib = 4*sx + sy;
*(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
}
} else {
const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8;
//const short dx = sx;
//const short dy = sy;
const short ly = (tiitg/NL1)%8;
const short ib = 4*sx + sy;
*(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
}
#else
// load data and store to threadgroup memory
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// no need for dequantization
for (short i = 0; i < 16; i++) {
const short sx = 2*il0 + i/8;
const short sy = (tiitg/NL0)/8;
const short lx = i%8;
const short ly = (tiitg/NL0)%8;
//const short lx = (tiitg/NL0)%8;
//const short ly = i%8;
*(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
}
} else {
S0_4x4 temp_a;
dequantize_func(x, il, temp_a);
threadgroup_barrier(mem_flags::mem_threadgroup);
FOR_UNROLL (short i = 0; i < 16; i++) {
const short sx = 2*il0 + i/8;
const short sy = (tiitg/NL0)/8;
const short lx = i%8;
const short ly = (tiitg/NL0)%8;
//const short lx = (tiitg/NL0)%8;
//const short ly = i%8;
*(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4];
}
}
if (FC_mul_mm_bc_inp) {
for (short i = 0; i < 8; ++i) {
const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8;
const short lx = i;
const short ly = (tiitg/NL1)%8;
//const short lx = (tiitg/NL1)%8;
//const short ly = i;
*(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
}
} else {
const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8;
//const short lx = i;
const short ly = (tiitg/NL1)%8;
//const short lx = (tiitg/NL1)%8;
//const short ly = i;
*(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y));
}
#endif
il = (il + 2 < nl) ? il + 2 : il % 2;
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
y += NK;
threadgroup_barrier(mem_flags::mem_threadgroup);
#ifndef GGML_METAL_HAS_TENSOR
// load matrices from threadgroup memory and conduct outer products
threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
simdgroup_barrier(mem_flags::mem_none);
FOR_UNROLL (short i = 0; i < 4; i++) {
simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
}
simdgroup_barrier(mem_flags::mem_none);
FOR_UNROLL (short i = 0; i < 2; i++) {
simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
}
simdgroup_barrier(mem_flags::mem_none);
FOR_UNROLL (short i = 0; i < 8; i++){
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
}
lsma += 8*64;
lsmb += 4*64;
}
#else
auto sA = tA.slice(0, 0);
auto sB = tB.slice(0, 0);
mm.run(sB, sA, cT);
#endif
}
// block is smaller than 64x32, we should avoid writing data outside of the matrix
threadgroup_barrier(mem_flags::mem_threadgroup);
#ifdef GGML_METAL_HAS_TENSOR
auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1));
cT.store(tC);
#else
threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
for (short i = 0; i < 8; i++) {
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
}
#endif
threadgroup_barrier(mem_flags::mem_threadgroup);
for (short j = sgitg; j < nr1; j += 4) {
const int id = ids_i32[im*args.ne21 + r1 + j];
const short ide = id % args.ne20;
const short idt = id / args.ne20;
device float * D = (device float *) dst + r0 + ide*args.ne0 + idt*args.ne1*args.ne0;
device float4 * D4 = (device float4 *) D;
threadgroup float * C = (threadgroup float *) shmem + j*NR0;
threadgroup float4 * C4 = (threadgroup float4 *) C;
int i = tiisg;
for (; i < nr0/4; i += 32) {
*(D4 + i) = *(C4 + i);
}
i = (4*(nr0/4)) + tiisg;
for (; i < nr0; i += 32) {
*(D + i) = *(C + i);
}
}
}
//
// matrix-matrix multiplication
//
typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>) mul_mm_t;
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, float, float2x4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
#endif
template [[host_name("kernel_mul_mm_q1_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q1_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q5_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q8_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_mxfp4_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q2_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q3_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q4_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q5_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q6_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq2_xxs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq2_xs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq3_xxs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq3_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq2_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq1_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq1_m_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq4_nl_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq4_xs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, half, half2x4>;
//
// indirect matrix-matrix multiplication
//
typedef decltype(kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>) mul_mm_id;
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, float, float2x4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
#endif
template [[host_name("kernel_mul_mm_id_q1_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_mxfp4_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q1_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq3_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq2_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq1_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, half, half2x4>;
File diff suppressed because it is too large Load Diff
-308
View File
@@ -1,308 +0,0 @@
#include "common.h"
// F == 1 : norm (no fuse)
// F == 2 : norm + mul
// F == 3 : norm + mul + add
template <typename T, short F>
kernel void kernel_norm_fuse_impl(
constant ggml_metal_kargs_norm & args,
device const char * src0,
device const char * src1_0,
device const char * src1_1,
device char * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
}
const int i01 = tgpig.x;
const int i02 = tgpig.y;
const int i03 = tgpig.z;
device const T * x = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
device const T * f0 = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
T sumft(0.0f);
float sumf = 0.0f;
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
sumft += x[i00];
}
sumf = dot(sumft, T(1.0f));
sumf = simd_sum(sumf);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_f32[tiisg];
sumf = simd_sum(sumf);
const float mean = sumf/args.ne00;
device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
sumf = 0.0f;
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
y[i00] = x[i00] - mean;
sumf += dot(y[i00], y[i00]);
}
sumf = simd_sum(sumf);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_f32[tiisg];
sumf = simd_sum(sumf);
const float variance = sumf/args.ne00;
const float scale = 1.0f/sqrt(variance + args.eps);
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
if (F == 1) {
y[i00] = (y[i00]*scale);
}
if (F == 2) {
y[i00] = (y[i00]*scale)*f0[i00];
}
if (F == 3) {
y[i00] = (y[i00]*scale)*f0[i00] + f1[i00];
}
}
}
typedef decltype(kernel_norm_fuse_impl<float4, 1>) kernel_norm_fuse_t;
template [[host_name("kernel_norm_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 1>;
template [[host_name("kernel_norm_mul_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 2>;
template [[host_name("kernel_norm_mul_add_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 3>;
template [[host_name("kernel_norm_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 1>;
template [[host_name("kernel_norm_mul_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 2>;
template [[host_name("kernel_norm_mul_add_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 3>;
// F == 1 : rms_norm (no fuse)
// F == 2 : rms_norm + mul
// F == 3 : rms_norm + mul + add
template <typename T, short F>
kernel void kernel_rms_norm_fuse_impl(
constant ggml_metal_kargs_norm & args,
device const char * src0,
device const char * src1_0,
device const char * src1_1,
device char * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
}
const int i01 = tgpig.x;
const int i02 = tgpig.y;
const int i03 = tgpig.z;
device const T * x = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
device const T * f0 = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
float sumf = 0.0f;
// parallel sum
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
sumf += dot(x[i00], x[i00]);
}
sumf = simd_sum(sumf);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_f32[tiisg];
sumf = simd_sum(sumf);
const float mean = sumf/args.ne00;
const float scale = 1.0f/sqrt(mean + args.eps);
device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
if (F == 1) {
y[i00] = (x[i00]*scale);
}
if (F == 2) {
y[i00] = (x[i00]*scale)*f0[i00];
}
if (F == 3) {
y[i00] = (x[i00]*scale)*f0[i00] + f1[i00];
}
}
}
typedef decltype(kernel_rms_norm_fuse_impl<float4, 1>) kernel_rms_norm_fuse_t;
template [[host_name("kernel_rms_norm_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 1>;
template [[host_name("kernel_rms_norm_mul_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 2>;
template [[host_name("kernel_rms_norm_mul_add_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 3>;
template [[host_name("kernel_rms_norm_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 1>;
template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 2>;
template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 3>;
template <typename T0, typename T>
kernel void kernel_l2_norm_impl(
constant ggml_metal_kargs_l2_norm & args,
device const char * src0,
device char * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i03 = tgpig.z;
const int i02 = tgpig.y;
const int i01 = tgpig.x;
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
}
device const T0 * x = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
float sumf = 0.0f;
// parallel sum
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
sumf += dot(x[i00], x[i00]);
}
sumf = simd_sum(sumf);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_f32[tiisg];
sumf = simd_sum(sumf);
const float scale = 1.0f/max(sqrt(sumf), args.eps);
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
y[i00] = x[i00] * scale;
}
}
typedef decltype(kernel_l2_norm_impl<float, float>) kernel_l2_norm_t;
template [[host_name("kernel_l2_norm_f32_f32")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float, float>;
template [[host_name("kernel_l2_norm_f32_f32_4")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float4, float4>;
kernel void kernel_group_norm_f32(
constant ggml_metal_kargs_group_norm & args,
device const float * src0,
device float * dst,
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint ntg[[threads_per_threadgroup]]) {
const int64_t ne = args.ne00*args.ne01*args.ne02;
const int64_t gs = args.ne00*args.ne01*((args.ne02 + args.ngrp - 1) / args.ngrp);
int start = tgpig * gs;
int end = start + gs;
start += tpitg;
if (end >= ne) {
end = ne;
}
float tmp = 0.0f; // partial sum for thread in warp
for (int j = start; j < end; j += ntg) {
tmp += src0[j];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
tmp = simd_sum(tmp);
if (ntg > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = 0.0f;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = tmp;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
tmp = buf[tiisg];
tmp = simd_sum(tmp);
}
const float mean = tmp / gs;
tmp = 0.0f;
for (int j = start; j < end; j += ntg) {
float xi = src0[j] - mean;
dst[j] = xi;
tmp += xi * xi;
}
tmp = simd_sum(tmp);
if (ntg > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = 0.0f;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = tmp;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
tmp = buf[tiisg];
tmp = simd_sum(tmp);
}
const float variance = tmp / gs;
const float scale = 1.0f/sqrt(variance + args.eps);
for (int j = start; j < end; j += ntg) {
dst[j] *= scale;
}
}
-148
View File
@@ -1,148 +0,0 @@
#include "common.h"
kernel void kernel_pool_2d_max_f32(
constant ggml_metal_kargs_pool_2d & args,
device const float * src0,
device float * dst,
uint gid[[thread_position_in_grid]]) {
if (gid >= args.np) {
return;
}
const int idx = gid;
const int I_HW = args.IH * args.IW;
const int O_HW = args.OH * args.OW;
const int nc = idx / O_HW;
const int cur_oh = idx % O_HW / args.OW;
const int cur_ow = idx % O_HW % args.OW;
device const float * i_ptr = src0 + nc * I_HW;
device float * o_ptr = dst + nc * O_HW;
const int start_h = cur_oh * args.s1 - args.p1;
const int bh = MAX(0, start_h);
const int eh = MIN(args.IH, start_h + args.k1);
const int start_w = cur_ow * args.s0 - args.p0;
const int bw = MAX(0, start_w);
const int ew = MIN(args.IW, start_w + args.k0);
float res = -INFINITY;
for (int i = bh; i < eh; i += 1) {
for (int j = bw; j < ew; j += 1) {
res = MAX(res, i_ptr[i * args.IW + j]);
}
}
o_ptr[cur_oh * args.OW + cur_ow] = res;
}
kernel void kernel_pool_2d_avg_f32(
constant ggml_metal_kargs_pool_2d & args,
device const float * src0,
device float * dst,
uint gid[[thread_position_in_grid]]) {
if (gid >= args.np) {
return;
}
const int idx = gid;
const int I_HW = args.IH * args.IW;
const int O_HW = args.OH * args.OW;
const int nc = idx / O_HW;
const int cur_oh = idx % O_HW / args.OW;
const int cur_ow = idx % O_HW % args.OW;
device const float * i_ptr = src0 + nc * I_HW;
device float * o_ptr = dst + nc * O_HW;
const int start_h = cur_oh * args.s1 - args.p1;
const int bh = MAX(0, start_h);
const int eh = MIN(args.IH, start_h + args.k1);
const int start_w = cur_ow * args.s0 - args.p0;
const int bw = MAX(0, start_w);
const int ew = MIN(args.IW, start_w + args.k0);
// const float scale = 1. / ((eh - bh) * (ew - bw));
const float scale = 1. / (args.k0 * args.k1);
float res = 0;
for (int i = bh; i < eh; i += 1) {
for (int j = bw; j < ew; j += 1) {
float cur = i_ptr[i * args.IW + j];
res += cur * scale;
}
}
o_ptr[cur_oh * args.OW + cur_ow] = res;
}
kernel void kernel_pool_1d_max_f32(
constant ggml_metal_kargs_pool_1d & args,
device const float * src,
device float * dst,
uint gid [[thread_position_in_grid]]
) {
if (gid >= args.np) {
return;
}
const int ow = (int)gid % args.OW;
const int row = (int)gid / args.OW;
const int base = ow * args.s0 - args.p0;
float acc = -INFINITY;
const int src_off = row * args.IW;
const int dst_off = row * args.OW;
for (int ki = 0; ki < args.k0; ++ki) {
int j = base + ki;
if (j < 0 || j >= args.IW){
continue;
}
float v = src[src_off + j];
acc = max(acc, v);
}
dst[dst_off + ow] = acc;
}
kernel void kernel_pool_1d_avg_f32(
constant ggml_metal_kargs_pool_1d & args,
device const float * src,
device float * dst,
uint gid [[thread_position_in_grid]]
) {
if (gid >= args.np) {
return;
}
const int ow = (int)gid % args.OW;
const int row = (int)gid / args.OW;
const int base = ow * args.s0 - args.p0;
float acc = 0.0f;
int cnt = 0;
const int src_off = row * args.IW;
const int dst_off = row * args.OW;
for (int ki = 0; ki < args.k0; ++ki) {
const int j = base + ki;
if (j < 0 || j >= args.IW) {
continue;
}
acc += src[src_off + j];
cnt += 1;
}
dst[dst_off + ow] = (cnt > 0) ? (acc / (float)cnt) : 0.0f;
}
-213
View File
@@ -1,213 +0,0 @@
#pragma once
#include "common.h"
void quantize_q1_0(device const float * src, device block_q1_0 & dst) {
float sum_abs = 0.0f;
for (int j = 0; j < QK1_0; j++) {
sum_abs += fabs(src[j]);
}
dst.d = sum_abs / QK1_0;
for (int j = 0; j < QK1_0 / 8; j++) {
dst.qs[j] = 0;
}
for (int j = 0; j < QK1_0; j++) {
if (src[j] >= 0.0f) {
dst.qs[j / 8] |= (1 << (j % 8));
}
}
}
void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
#pragma METAL fp math_mode(safe)
float amax = 0.0f; // absolute max
float max = 0.0f;
for (int j = 0; j < QK4_0; j++) {
const float v = src[j];
if (amax < fabs(v)) {
amax = fabs(v);
max = v;
}
}
const float d = max / -8;
const float id = d ? 1.0f/d : 0.0f;
dst.d = d;
for (int j = 0; j < QK4_0/2; ++j) {
const float x0 = src[0 + j]*id;
const float x1 = src[QK4_0/2 + j]*id;
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
dst.qs[j] = xi0;
dst.qs[j] |= xi1 << 4;
}
}
void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
#pragma METAL fp math_mode(safe)
float min = FLT_MAX;
float max = -FLT_MAX;
for (int j = 0; j < QK4_1; j++) {
const float v = src[j];
if (min > v) min = v;
if (max < v) max = v;
}
const float d = (max - min) / ((1 << 4) - 1);
const float id = d ? 1.0f/d : 0.0f;
dst.d = d;
dst.m = min;
for (int j = 0; j < QK4_1/2; ++j) {
const float x0 = (src[0 + j] - min)*id;
const float x1 = (src[QK4_1/2 + j] - min)*id;
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
dst.qs[j] = xi0;
dst.qs[j] |= xi1 << 4;
}
}
void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
#pragma METAL fp math_mode(safe)
float amax = 0.0f; // absolute max
float max = 0.0f;
for (int j = 0; j < QK5_0; j++) {
const float v = src[j];
if (amax < fabs(v)) {
amax = fabs(v);
max = v;
}
}
const float d = max / -16;
const float id = d ? 1.0f/d : 0.0f;
dst.d = d;
uint32_t qh = 0;
for (int j = 0; j < QK5_0/2; ++j) {
const float x0 = src[0 + j]*id;
const float x1 = src[QK5_0/2 + j]*id;
const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
}
thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
for (int j = 0; j < 4; ++j) {
dst.qh[j] = qh8[j];
}
}
void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
#pragma METAL fp math_mode(safe)
float max = src[0];
float min = src[0];
for (int j = 1; j < QK5_1; j++) {
const float v = src[j];
min = v < min ? v : min;
max = v > max ? v : max;
}
const float d = (max - min) / 31;
const float id = d ? 1.0f/d : 0.0f;
dst.d = d;
dst.m = min;
uint32_t qh = 0;
for (int j = 0; j < QK5_1/2; ++j) {
const float x0 = (src[0 + j] - min)*id;
const float x1 = (src[QK5_1/2 + j] - min)*id;
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
}
thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
for (int j = 0; j < 4; ++j) {
dst.qh[j] = qh8[j];
}
}
void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
#pragma METAL fp math_mode(safe)
float amax = 0.0f; // absolute max
for (int j = 0; j < QK8_0; j++) {
const float v = src[j];
amax = MAX(amax, fabs(v));
}
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
dst.d = d;
for (int j = 0; j < QK8_0; ++j) {
const float x0 = src[j]*id;
dst.qs[j] = round(x0);
}
}
void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
#pragma METAL fp math_mode(safe)
float amax = 0.0f; // absolute max
float max = 0.0f;
for (int j = 0; j < QK4_NL; j++) {
const float v = src[j];
if (amax < fabs(v)) {
amax = fabs(v);
max = v;
}
}
const float d = max / kvalues_iq4nl_f[0];
const float id = d ? 1.0f/d : 0.0f;
float sumqx = 0, sumq2 = 0;
for (int j = 0; j < QK4_NL/2; ++j) {
const float x0 = src[0 + j]*id;
const float x1 = src[QK4_NL/2 + j]*id;
const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
dst.qs[j] = xi0 | (xi1 << 4);
const float v0 = kvalues_iq4nl_f[xi0];
const float v1 = kvalues_iq4nl_f[xi1];
const float w0 = src[0 + j]*src[0 + j];
const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
sumq2 += w0*v0*v0 + w1*v1*v1;
}
dst.d = sumq2 > 0 ? sumqx/sumq2 : d;
}
-389
View File
@@ -1,389 +0,0 @@
#include "common.h"
#include "dequantize.h"
#include "quantize.h"
template<typename T0, typename T1>
kernel void kernel_cpy_t_t(
constant ggml_metal_kargs_cpy & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int32_t i03 = tgpig[2];
const int32_t i02 = tgpig[1];
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
if (i01 >= args.ne01) {
return;
}
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
const int32_t i3 = n/(args.ne2*args.ne1*args.ne0);
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.ne00;) {
device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
dst_data[i00] = (T1) src[0];
break;
}
}
typedef decltype(kernel_cpy_t_t<float, float>) kernel_cpy_t;
template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<float, float>;
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, half>;
template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t<float, int32_t>;
template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<int32_t, float>;
template [[host_name("kernel_cpy_i32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t<int32_t, int32_t>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, bfloat>;
#endif
template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<half, float>;
template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy_t_t<half, half>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<bfloat, float>;
template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<bfloat, bfloat>;
#endif
template<short QK,
typename block_q,
void (*quantize_func)(device const float *, device block_q &)>
kernel void kernel_cpy_f32_q(
constant ggml_metal_kargs_cpy & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int32_t i03 = tgpig[2];
const int32_t i02 = tgpig[1];
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
if (i01 >= args.ne01) {
return;
}
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
const int32_t i3 = n / (args.ne2*args.ne1*args.ne0);
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK;
device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) {
device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00);
quantize_func(src, dst_data[i00]);
break;
}
}
typedef decltype(kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>) cpy_f_q_t;
template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>;
template [[host_name("kernel_cpy_f32_q1_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK1_0, block_q1_0, quantize_q1_0>;
template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_0, block_q4_0, quantize_q4_0>;
template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_1, block_q4_1, quantize_q4_1>;
template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_0, block_q5_0, quantize_q5_0>;
template [[host_name("kernel_cpy_f32_q5_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_1, block_q5_1, quantize_q5_1>;
template [[host_name("kernel_cpy_f32_iq4_nl")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_NL, block_iq4_nl, quantize_iq4_nl>;
template<typename T4x4, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
kernel void kernel_cpy_q_f32(
constant ggml_metal_kargs_cpy & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int32_t i03 = tgpig[2];
const int32_t i02 = tgpig[1];
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
if (i01 >= args.ne01) {
return;
}
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
const int32_t i3 = n/(args.ne2*args.ne1*args.ne0);
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) {
T4x4 temp;
dequantize_func(src_data + i00/nl, i00%nl, temp);
dst_data[i00] = temp;
break;
}
}
typedef decltype(kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>) cpy_q_f_t;
template [[host_name("kernel_cpy_q1_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q1_0, 8, dequantize_q1_0>;
template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_cpy_q1_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q1_0, 8, dequantize_q1_0>;
template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_cpy_q5_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_cpy_q8_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q8_0, 2, dequantize_q8_0>;
template<typename T>
kernel void kernel_concat(
constant ggml_metal_kargs_concat & args,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i3 = tgpig.z;
const int i2 = tgpig.y;
const int i1 = ntg.y == 1 ? tgpig.x : tgpig.x*ntg.y + tpitg.y;
if (i1 >= args.ne1) {
return;
}
int o[4] = {0, 0, 0, 0};
o[args.dim] = args.dim == 0 ? args.ne00 : (args.dim == 1 ? args.ne01 : (args.dim == 2 ? args.ne02 : args.ne03));
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
device const T * x;
if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
x = (device const T *)(src0 + (i3 )*args.nb03 + (i2 )*args.nb02 + (i1 )*args.nb01 + (i0 )*args.nb00);
} else {
x = (device const T *)(src1 + (i3 - o[3])*args.nb13 + (i2 - o[2])*args.nb12 + (i1 - o[1])*args.nb11 + (i0 - o[0])*args.nb10);
}
device T * y = (device T *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
*y = *x;
}
}
typedef decltype(kernel_concat<float>) kernel_concat_t;
template [[host_name("kernel_concat_f32")]] kernel kernel_concat_t kernel_concat<float>;
template [[host_name("kernel_concat_f16")]] kernel kernel_concat_t kernel_concat<half>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_concat_bf16")]] kernel kernel_concat_t kernel_concat<bfloat>;
#endif
template [[host_name("kernel_concat_i8")]] kernel kernel_concat_t kernel_concat<char>;
template [[host_name("kernel_concat_i16")]] kernel kernel_concat_t kernel_concat<short>;
template [[host_name("kernel_concat_i32")]] kernel kernel_concat_t kernel_concat<int>;
template [[host_name("kernel_concat_i64")]] kernel kernel_concat_t kernel_concat<long>;
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
kernel void kernel_get_rows_q(
constant ggml_metal_kargs_get_rows & args,
device const void * src0,
device const void * src1,
device void * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 ntg [[threads_per_threadgroup]]) {
const int32_t iw0 = tgpig.x/args.ne10;
const int32_t i10 = tgpig.x%args.ne10;
const int32_t i11 = tgpig.y;
const int32_t i12 = tgpig.z;
const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0];
const int32_t i02 = i11;
const int32_t i03 = i12;
auto psrc = (device const block_q *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01);
auto pdst = (device float4x4 *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1);
for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) {
float4x4 temp;
dequantize_func(psrc + ind/nl, ind%nl, temp);
pdst[ind] = temp;
break;
}
}
template<typename T0, typename T>
kernel void kernel_get_rows_f(
constant ggml_metal_kargs_get_rows & args,
device const void * src0,
device const void * src1,
device void * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 ntg [[threads_per_threadgroup]]) {
const int32_t iw0 = tgpig.x/args.ne10;
const int32_t i10 = tgpig.x%args.ne10;
const int32_t i11 = tgpig.y;
const int32_t i12 = tgpig.z;
const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0];
const int32_t i02 = i11;
const int32_t i03 = i12;
auto psrc = (const device T0 *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01);
auto pdst = ( device T *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1);
for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) {
pdst[ind] = psrc[ind];
break;
}
}
template<typename TI, typename block_q, void (*quantize_func)(device const float *, device block_q &)>
kernel void kernel_set_rows_q32(
constant ggml_metal_kargs_set_rows & args,
device const void * src0,
device const void * src1,
device float * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg [[threads_per_threadgroup]]) {
const int32_t i03 = tgpig.z;
const int32_t i02 = tgpig.y;
const int32_t i12 = i03%args.ne12;
const int32_t i11 = i02%args.ne11;
const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
if (i01 >= args.ne01) {
return;
}
const int32_t i10 = i01;
const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
device block_q * dst_row = ( device block_q *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
quantize_func(src_row + 32*ind, dst_row[ind]);
}
}
template<typename T, typename TI>
kernel void kernel_set_rows_f(
constant ggml_metal_kargs_set_rows & args,
device const void * src0,
device const void * src1,
device float * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg [[threads_per_threadgroup]]) {
const int32_t i03 = tgpig.z;
const int32_t i02 = tgpig.y;
const int32_t i12 = i03%args.ne12;
const int32_t i11 = i02%args.ne11;
const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
if (i01 >= args.ne01) {
return;
}
const int32_t i10 = i01;
const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
dst_row[ind] = (T) src_row[ind];
}
}
//
// get rows
//
typedef decltype(kernel_get_rows_f<float, float>) get_rows_f_t;
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float, float>;
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half, float>;
template [[host_name("kernel_get_rows_i32")]] kernel get_rows_f_t kernel_get_rows_f<int32_t, int32_t>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat, float>;
#endif
typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
template [[host_name("kernel_get_rows_q1_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q1_0, 8, dequantize_q1_0>;
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_get_rows_mxfp4")]] kernel get_rows_q_t kernel_get_rows_q<block_mxfp4, 2, dequantize_mxfp4>;
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q6_K, QK_NL, dequantize_q6_K>;
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_s, QK_NL, dequantize_iq3_s>;
template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_s, QK_NL, dequantize_iq2_s>;
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_s, QK_NL, dequantize_iq1_s>;
template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_m, QK_NL, dequantize_iq1_m>;
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
//
// set rows
//
typedef decltype(kernel_set_rows_f<float, int64_t>) set_rows_f_t;
template [[host_name("kernel_set_rows_f32_i64")]] kernel set_rows_f_t kernel_set_rows_f<float, int64_t>;
template [[host_name("kernel_set_rows_f32_i32")]] kernel set_rows_f_t kernel_set_rows_f<float, int32_t>;
template [[host_name("kernel_set_rows_f16_i64")]] kernel set_rows_f_t kernel_set_rows_f<half, int64_t>;
template [[host_name("kernel_set_rows_f16_i32")]] kernel set_rows_f_t kernel_set_rows_f<half, int32_t>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_set_rows_bf16_i64")]] kernel set_rows_f_t kernel_set_rows_f<bfloat, int64_t>;
template [[host_name("kernel_set_rows_bf16_i32")]] kernel set_rows_f_t kernel_set_rows_f<bfloat, int32_t>;
#endif
typedef decltype(kernel_set_rows_q32<int64_t, block_q8_0, quantize_q8_0>) set_rows_q32_t;
template [[host_name("kernel_set_rows_q8_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q8_0, quantize_q8_0>;
template [[host_name("kernel_set_rows_q8_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q8_0, quantize_q8_0>;
template [[host_name("kernel_set_rows_q4_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q4_0, quantize_q4_0>;
template [[host_name("kernel_set_rows_q4_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q4_0, quantize_q4_0>;
template [[host_name("kernel_set_rows_q4_1_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q4_1, quantize_q4_1>;
template [[host_name("kernel_set_rows_q4_1_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q4_1, quantize_q4_1>;
template [[host_name("kernel_set_rows_q5_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q5_0, quantize_q5_0>;
template [[host_name("kernel_set_rows_q5_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q5_0, quantize_q5_0>;
template [[host_name("kernel_set_rows_q5_1_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q5_1, quantize_q5_1>;
template [[host_name("kernel_set_rows_q5_1_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q5_1, quantize_q5_1>;
template [[host_name("kernel_set_rows_iq4_nl_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_iq4_nl, quantize_iq4_nl>;
template [[host_name("kernel_set_rows_iq4_nl_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_iq4_nl, quantize_iq4_nl>;
-228
View File
@@ -1,228 +0,0 @@
#include "common.h"
kernel void kernel_op_sum_f32(
constant ggml_metal_kargs_sum & args,
device const float * src0,
device float * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
if (args.np == 0) {
return;
}
// TODO: become function constant
const uint nsg = (ntg.x + 31) / 32;
float sumf = 0;
for (uint64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) {
sumf += src0[i0];
}
sumf = simd_sum(sumf);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float total = 0;
if (sgitg == 0) {
float v = 0;
if (tpitg.x < nsg) {
v = shmem_f32[tpitg.x];
}
total = simd_sum(v);
if (tpitg.x == 0) {
dst[0] = total;
}
}
}
constant short FC_sum_rows_op [[function_constant(FC_SUM_ROWS + 0)]];
template <typename T0, typename T>
kernel void kernel_sum_rows_impl(
constant ggml_metal_kargs_sum_rows & args,
device const char * src0,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
#define FC_OP FC_sum_rows_op
const int i3 = tgpig.z;
const int i2 = tgpig.y;
const int i1 = tgpig.x;
threadgroup T0 * shmem_t = (threadgroup T0 *) shmem;
if (sgitg == 0) {
shmem_t[tiisg] = 0.0f;
}
device const T0 * src_row = (device const T0 *) (src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
device T * dst_row = (device T *) (dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
T0 sumf = T0(0.0f);
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
sumf += src_row[i0];
}
sumf = simd_sum(sumf);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_t[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_t[tiisg];
sumf = simd_sum(sumf);
if (tpitg.x == 0) {
if (FC_OP == OP_SUM_ROWS_NUM_MEAN) {
if (is_same<float4, T0>::value) {
dst_row[0] = sum(sumf) / (4*args.ne00);
} else {
dst_row[0] = sum(sumf) / args.ne00;
}
} else {
dst_row[0] = sum(sumf);
}
}
#undef FC_OP
}
typedef decltype(kernel_sum_rows_impl<float, float>) kernel_sum_rows_t;
template [[host_name("kernel_sum_rows_f32_f32")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float, float>;
template [[host_name("kernel_sum_rows_f32_f32_4")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float4, float>;
template<typename T>
kernel void kernel_cumsum_blk(
constant ggml_metal_kargs_cumsum_blk & args,
device const char * src0,
device char * tmp,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int ib = tgpig[0]/args.ne01;
const int i00 = ib*ntg.x;
const int i01 = tgpig[0]%args.ne01;
const int i02 = tgpig[1];
const int i03 = tgpig[2];
device const float * src0_row = (device const float *) (src0 +
args.nb01*i01 +
args.nb02*i02 +
args.nb03*i03);
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
float v = 0.0f;
if (i00 + tpitg.x < args.ne00) {
v = src0_row[i00 + tpitg.x];
}
float s = simd_prefix_inclusive_sum(v);
if (tiisg == N_SIMDWIDTH - 1) {
shmem_f32[sgitg] = s;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (sgitg == 0) {
shmem_f32[tiisg] = simd_prefix_exclusive_sum(shmem_f32[tiisg]);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
s += shmem_f32[sgitg];
device float * dst_row = (device float *) dst +
args.ne00*i01 +
args.ne00*args.ne01*i02 +
args.ne00*args.ne01*args.ne02*i03;
if (i00 + tpitg.x < args.ne00) {
dst_row[i00 + tpitg.x] = s;
}
if (args.outb && tpitg.x == ntg.x - 1) {
device float * tmp_row = (device float *) tmp +
args.net0*i01 +
args.net0*args.net1*i02 +
args.net0*args.net1*args.net2*i03;
tmp_row[ib] = s;
}
}
typedef decltype(kernel_cumsum_blk<float>) kernel_cumsum_blk_t;
template [[host_name("kernel_cumsum_blk_f32")]] kernel kernel_cumsum_blk_t kernel_cumsum_blk<float>;
template<typename T>
kernel void kernel_cumsum_add(
constant ggml_metal_kargs_cumsum_add & args,
device const char * tmp,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int ib = tgpig[0]/args.ne01;
if (ib == 0) {
return;
}
const int i00 = ib*ntg.x;
const int i01 = tgpig[0]%args.ne01;
const int i02 = tgpig[1];
const int i03 = tgpig[2];
device const float * tmp_row = (device const float *) (tmp +
args.nbt1*i01 +
args.nbt2*i02 +
args.nbt3*i03);
device float * dst_row = (device float *) dst +
args.ne00*i01 +
args.ne00*args.ne01*i02 +
args.ne00*args.ne01*args.ne02*i03;
if (i00 + tpitg.x < args.ne00) {
dst_row[i00 + tpitg.x] += tmp_row[ib - 1];
}
}
typedef decltype(kernel_cumsum_add<float>) kernel_cumsum_add_t;
template [[host_name("kernel_cumsum_add_f32")]] kernel kernel_cumsum_add_t kernel_cumsum_add<float>;
-318
View File
@@ -1,318 +0,0 @@
#include "common.h"
constant bool FC_rope_is_imrope [[function_constant(FC_ROPE + 0)]];
constant bool FC_rope_is_back [[function_constant(FC_ROPE + 1)]];
static float rope_yarn_ramp(const float low, const float high, const int i0) {
const float y = (i0 / 2 - low) / max(0.001f, high - low);
return 1.0f - min(1.0f, max(0.0f, y));
}
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
static void rope_yarn(
float theta_extrap, float freq_scale, float corr_dims[2], int i0, float ext_factor, float mscale,
thread float * cos_theta, thread float * sin_theta) {
// Get n-d rotational scaling corrected for extrapolation
float theta_interp = freq_scale * theta_extrap;
float theta = theta_interp;
if (ext_factor != 0.0f) {
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
// Get n-d magnitude scaling corrected for interpolation
mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
}
*cos_theta = cos(theta) * mscale;
*sin_theta = sin(theta) * mscale;
if (FC_rope_is_back) {
*sin_theta *= -1.0f;
}
}
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));
}
static void rope_yarn_corr_dims(
int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
) {
// start and end correction dims
dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
}
template<typename T>
kernel void kernel_rope_norm(
constant ggml_metal_kargs_rope & args,
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 tptg [[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]) {
const int i3 = tgpig[2];
const int i2 = tgpig[1];
const int i1 = tgpig[0];
float corr_dims[2];
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
device const int32_t * pos = (device const int32_t *) src1;
const float theta_base = (float) pos[i2];
const float inv_ndims = -1.f/args.n_dims;
float cos_theta;
float sin_theta;
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
if (i0 < args.n_dims) {
const int ic = i0/2;
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
const float x0 = src[0];
const float x1 = src[1];
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[1] = x0*sin_theta + x1*cos_theta;
} else {
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
template<typename T>
kernel void kernel_rope_neox(
constant ggml_metal_kargs_rope & args,
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 tptg [[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]) {
const int i3 = tgpig[2];
const int i2 = tgpig[1];
const int i1 = tgpig[0];
float corr_dims[2];
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
device const int32_t * pos = (device const int32_t *) src1;
const float theta_base = (float) pos[i2];
const float inv_ndims = -1.f/args.n_dims;
float cos_theta;
float sin_theta;
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
if (i0 < args.n_dims) {
const int ic = i0/2;
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
const float x0 = src[0];
const float x1 = src[args.n_dims/2];
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
} else {
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
template<typename T>
kernel void kernel_rope_multi(
constant ggml_metal_kargs_rope & args,
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 tptg [[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]) {
const int i3 = tgpig[2];
const int i2 = tgpig[1];
const int i1 = tgpig[0];
float corr_dims[2];
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
device const int32_t * pos = (device const int32_t *) src1;
const float inv_ndims = -1.f/args.n_dims;
float cos_theta;
float sin_theta;
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
if (i0 < args.n_dims) {
const int ic = i0/2;
// mrope theta calculations
// note: the rest is the same as kernel_rope_neox
const int sect_dims = args.sect_0 + args.sect_1 + args.sect_2 + args.sect_3;
const int sec_w01 = args.sect_0 + args.sect_1; // end of section 1
const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2
const int sector = ic % sect_dims;
float theta_base;
if (FC_rope_is_imrope) {
if (sector % 3 == 1 && sector < 3 * args.sect_1) { // h
theta_base = (float) pos[i2 + args.ne02 * 1];
} else if (sector % 3 == 2 && sector < 3 * args.sect_2) { // w
theta_base = (float) pos[i2 + args.ne02 * 2];
} else if (sector % 3 == 0 && sector < 3 * args.sect_0) { // t
theta_base = (float) pos[i2 + args.ne02 * 0];
} else { // e
theta_base = (float) pos[i2 + args.ne02 * 3];
}
} else {
if (sector < args.sect_0) {
theta_base = (float) pos[i2];
} else if (sector < sec_w01) {
theta_base = (float) pos[i2 + args.ne02 * 1];
} else if (sector < sec_w012) {
theta_base = (float) pos[i2 + args.ne02 * 2];
} else {
theta_base = (float) pos[i2 + args.ne02 * 3];
}
}
// end of mrope
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
const float x0 = src[0];
const float x1 = src[args.n_dims/2];
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
} else {
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
template<typename T>
kernel void kernel_rope_vision(
constant ggml_metal_kargs_rope & args,
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 tptg [[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]) {
const int i3 = tgpig[2];
const int i2 = tgpig[1];
const int i1 = tgpig[0];
float corr_dims[2];
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
device const int32_t * pos = (device const int32_t *) src1;
const float inv_ndims = -1.f/args.n_dims;
float cos_theta;
float sin_theta;
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
if (i0 < 2*args.n_dims) { // different from kernel_rope_multi
const int ic = i0/2;
// mrope theta calculations (only support 2 dimensions)
const int sect_dims = args.sect_0 + args.sect_1;
const int sector = ic % sect_dims;
float p;
float theta_base;
if (sector < args.sect_1) {
p = (float) sector;
theta_base = (float) pos[i2];
} else {
p = (float) sector - args.sect_0;
theta_base = (float) pos[i2 + args.ne02];
}
const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
// end of mrope
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
const float x0 = src[0];
const float x1 = src[args.n_dims]; // different from kernel_rope_multi
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; // different from kernel_rope_multi
} else {
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
typedef decltype(kernel_rope_multi<float>) kernel_rope_multi_t;
typedef decltype(kernel_rope_vision<float>) kernel_rope_vision_t;
template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi<float>;
template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi<half>;
template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
-223
View File
@@ -1,223 +0,0 @@
#include "common.h"
template<typename T>
kernel void kernel_soft_max(
constant ggml_metal_kargs_soft_max & args,
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
threadgroup float * buf [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint3 tptg[[threads_per_threadgroup]]) {
const int32_t i03 = tgpig.z;
const int32_t i02 = tgpig.y;
const int32_t i01 = tgpig.x;
const int32_t i13 = i03%args.ne13;
const int32_t i12 = i02%args.ne12;
const int32_t i11 = i01;
device const float * psrc0 = (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
device const float * psrc2 = src2 != src0 ? (device const float *) (src2) : nullptr;
device float * pdst = (device float *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
float slope = 1.0f;
// ALiBi
if (args.max_bias > 0.0f) {
const int32_t h = i02;
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
slope = pow(base, exp);
}
// parallel max
float lmax = psrc2 ? psrc2[i02] : -INFINITY;
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
}
// find the max value in the block
float max_val = simd_max(lmax);
if (tptg.x > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = -INFINITY;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = max_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max_val = buf[tiisg];
max_val = simd_max(max_val);
}
// parallel sum
float lsum = 0.0f;
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
lsum += exp_psrc0;
pdst[i00] = exp_psrc0;
}
// This barrier fixes a failing test
// ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335
threadgroup_barrier(mem_flags::mem_none);
float sum = simd_sum(lsum);
if (tptg.x > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = 0.0f;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = sum;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sum = buf[tiisg];
sum = simd_sum(sum);
}
if (psrc2) {
sum += exp(psrc2[i02] - max_val);
}
const float inv_sum = 1.0f/sum;
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
pdst[i00] *= inv_sum;
}
}
template<typename T>
kernel void kernel_soft_max_4(
constant ggml_metal_kargs_soft_max & args,
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
threadgroup float * buf [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint3 tptg[[threads_per_threadgroup]]) {
const int32_t i03 = tgpig.z;
const int32_t i02 = tgpig.y;
const int32_t i01 = tgpig.x;
const int32_t i13 = i03%args.ne13;
const int32_t i12 = i02%args.ne12;
const int32_t i11 = i01;
device const float4 * psrc4 = (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
device const float * psrc2 = src2 != src0 ? (device const float * ) (src2) : nullptr;
device float4 * pdst4 = (device float4 *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
float slope = 1.0f;
if (args.max_bias > 0.0f) {
const int32_t h = i02;
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
slope = pow(base, exp);
}
// parallel max
float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
}
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
float max_val = simd_max(lmax);
if (tptg.x > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = -INFINITY;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = max_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max_val = buf[tiisg];
max_val = simd_max(max_val);
}
// parallel sum
float4 lsum4 = 0.0f;
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
lsum4 += exp_psrc4;
pdst4[i00] = exp_psrc4;
}
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
// This barrier fixes a failing test
// ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335
threadgroup_barrier(mem_flags::mem_none);
float sum = simd_sum(lsum);
if (tptg.x > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = 0.0f;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = sum;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sum = buf[tiisg];
sum = simd_sum(sum);
}
if (psrc2) {
sum += exp(psrc2[i02] - max_val);
}
const float inv_sum = 1.0f/sum;
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
pdst4[i00] *= inv_sum;
}
}
typedef decltype(kernel_soft_max<float>) kernel_soft_max_t;
typedef decltype(kernel_soft_max_4<float4>) kernel_soft_max_4_t;
template [[host_name("kernel_soft_max_f16")]] kernel kernel_soft_max_t kernel_soft_max<half>;
template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max<float>;
template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
@@ -1,75 +0,0 @@
#include "common.h"
constant short FC_solve_tri_nsg [[function_constant(FC_SOLVE_TRI + 0)]];
constant short FC_solve_tri_n [[function_constant(FC_SOLVE_TRI + 1)]];
constant short FC_solve_tri_k [[function_constant(FC_SOLVE_TRI + 2)]];
kernel void kernel_solve_tri_f32(
constant ggml_metal_kargs_solve_tri & args,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
ushort3 tgpig[[threadgroup_position_in_grid]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
constexpr short NW = N_SIMDWIDTH;
const short NSG = FC_solve_tri_nsg;
const short N = FC_solve_tri_n;
const short K = FC_solve_tri_k;
const short NP = PAD2(N, NW);
const int32_t i03 = tgpig.z;
const int32_t i02 = tgpig.y;
const int32_t i01 = tgpig.x*NSG + sgitg;
threadgroup float * sh0 = (threadgroup float *) shmem;
device const float * src0_ptr = (device const float *)(src0 + i02 * args.nb02 + i03 * args.nb03) + sgitg*N;
device const float * src1_ptr = (device const float *)(src1 + i02 * args.nb12 + i03 * args.nb13) + i01;
device float * dst_ptr = (device float *)(dst + i02 * args.nb2 + i03 * args.nb3) + i01;
for (short rr = 0; rr < N; rr += NSG) {
threadgroup_barrier(mem_flags::mem_threadgroup);
{
threadgroup float * sh0_cur = sh0 + sgitg*NP;
for (short t = 0; t*NW < N; ++t) {
const short idx = t*NW + tiisg;
sh0_cur[idx] = src0_ptr[idx];
}
src0_ptr += NSG*N;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (i01 >= args.ne10) {
continue;
}
for (short ir = 0; ir < NSG && rr + ir < N; ++ir) {
const short r = rr + ir;
threadgroup float * sh0_cur = sh0 + ir*NP;
float sum = 0.0f;
for (short t = 0; t*NW < r; ++t) {
const short idx = t*NW + tiisg;
sum += sh0_cur[idx] * dst_ptr[idx*K] * (idx < r);
}
sum = simd_sum(sum);
if (tiisg == 0) {
const float diag = sh0_cur[r];
dst_ptr[r*K] = (src1_ptr[r*K] - sum) / diag;
}
}
}
}
-279
View File
@@ -1,279 +0,0 @@
#include "common.h"
// ref: ggml.c:ggml_compute_forward_ssm_conv_f32
kernel void kernel_ssm_conv_f32_f32(
constant ggml_metal_kargs_ssm_conv & args,
device const void * src0,
device const void * src1,
device float * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t ir = tgpig.x;
const int64_t i2 = tgpig.y;
const int64_t i3 = tgpig.z;
const int64_t nc = args.ne10;
//const int64_t ncs = args.ne00;
//const int64_t nr = args.ne01;
//const int64_t n_t = args.ne1;
//const int64_t n_s = args.ne2;
device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
float sumf = 0.0f;
for (int64_t i0 = 0; i0 < nc; ++i0) {
sumf += s[i0] * c[i0];
}
x[0] = sumf;
}
kernel void kernel_ssm_conv_f32_f32_4(
constant ggml_metal_kargs_ssm_conv & args,
device const void * src0,
device const void * src1,
device float * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t ir = tgpig.x;
const int64_t i2 = tgpig.y;
const int64_t i3 = tgpig.z;
const int64_t nc = args.ne10;
//const int64_t ncs = args.ne00;
//const int64_t nr = args.ne01;
//const int64_t n_t = args.ne1;
//const int64_t n_s = args.ne2;
device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11);
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
float sumf = 0.0f;
for (int64_t i0 = 0; i0 < nc/4; ++i0) {
sumf += dot(s[i0], c[i0]);
}
x[0] = sumf;
}
constant short FC_ssm_conv_bs [[function_constant(FC_SSM_CONV + 0)]];
// Batched version: each threadgroup processes multiple tokens for better efficiency
// Thread layout: each thread handles one token, threadgroup covers BATCH_SIZE tokens
kernel void kernel_ssm_conv_f32_f32_batched(
constant ggml_metal_kargs_ssm_conv & args,
device const void * src0,
device const void * src1,
device float * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
// tgpig.x = row index (ir)
// tgpig.y = batch of tokens (i2_base / BATCH_SIZE)
// tgpig.z = sequence index (i3)
// tpitg.x = thread within batch (0..BATCH_SIZE-1)
const short BATCH_SIZE = FC_ssm_conv_bs;
const int64_t ir = tgpig.x;
const int64_t i2_base = tgpig.y * BATCH_SIZE;
const int64_t i3 = tgpig.z;
const int64_t i2_off = tpitg.x;
const int64_t i2 = i2_base + i2_off;
const int64_t nc = args.ne10; // conv kernel size (typically 4)
const int64_t n_t = args.ne1; // number of tokens
// Bounds check for partial batches at the end
if (i2 >= n_t) {
return;
}
// Load conv weights (shared across all tokens for this row)
device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
// Load source for this specific token
device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
// Output location for this token
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
float sumf = 0.0f;
for (int64_t i0 = 0; i0 < nc; ++i0) {
sumf += s[i0] * c[i0];
}
x[0] = sumf;
}
kernel void kernel_ssm_conv_f32_f32_batched_4(
constant ggml_metal_kargs_ssm_conv & args,
device const void * src0,
device const void * src1,
device float * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
// tgpig.x = row index (ir)
// tgpig.y = batch of tokens (i2_base / BATCH_SIZE)
// tgpig.z = sequence index (i3)
// tpitg.x = thread within batch (0..BATCH_SIZE-1)
const short BATCH_SIZE = FC_ssm_conv_bs;
const int64_t ir = tgpig.x;
const int64_t i2_base = tgpig.y * BATCH_SIZE;
const int64_t i3 = tgpig.z;
const int64_t i2_off = tpitg.x;
const int64_t i2 = i2_base + i2_off;
const int64_t nc = args.ne10; // conv kernel size (typically 4)
const int64_t n_t = args.ne1; // number of tokens
// Bounds check for partial batches at the end
if (i2 >= n_t) {
return;
}
// Load conv weights (shared across all tokens for this row)
device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11);
// Load source for this specific token
device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
// Output location for this token
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
float sumf = 0.0f;
for (int64_t i0 = 0; i0 < nc/4; ++i0) {
sumf += dot(s[i0], c[i0]);
}
x[0] = sumf;
}
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
// Optimized version: reduces redundant memory loads by having one thread load shared values
kernel void kernel_ssm_scan_f32(
constant ggml_metal_kargs_ssm_scan & args,
device const void * src0,
device const void * src1,
device const void * src2,
device const void * src3,
device const void * src4,
device const void * src5,
device const void * src6,
device float * dst,
threadgroup float * shared [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgptg[[simdgroups_per_threadgroup]],
uint3 tgpg[[threadgroups_per_grid]]) {
constexpr short NW = N_SIMDWIDTH;
// Shared memory layout:
// [0..sgptg*NW-1]: partial sums for reduction (existing)
// [sgptg*NW..sgptg*NW+sgptg-1]: pre-computed x_dt values for each token in batch
// [sgptg*NW+sgptg..sgptg*NW+2*sgptg-1]: pre-computed dA values for each token in batch
threadgroup float * shared_sums = shared;
threadgroup float * shared_x_dt = shared + sgptg * NW;
threadgroup float * shared_dA = shared + sgptg * NW + sgptg;
shared_sums[tpitg.x] = 0.0f;
const int32_t i0 = tpitg.x;
const int32_t i1 = tgpig.x;
const int32_t ir = tgpig.y; // current head
const int32_t i3 = tgpig.z; // current seq
const int32_t nc = args.d_state;
const int32_t nr = args.d_inner;
const int32_t nh = args.n_head;
const int32_t ng = args.n_group;
const int32_t n_t = args.n_seq_tokens;
const int32_t s_off = args.s_off;
device const int32_t * ids = (device const int32_t *) src6;
device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
const int32_t i = i0 + i1*nc;
const int32_t g = ir / (nh / ng); // repeat_interleave
float s0 = s0_buff[i];
float s = 0.0f;
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {ne30, nh}
const float A0 = A[i0%args.ne30];
device const float * x = (device const float *)((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i3*args.nb13); // {dim, nh, nt, ns}
device const float * dt = (device const float *)((device const char *) src2 + ir*args.nb20 + i3*args.nb22); // {nh, nt, ns}
device const float * B = (device const float *)((device const char *) src4 + g*args.nb41 + i3*args.nb43); // {d_state, ng, nt, ns}
device const float * C = (device const float *)((device const char *) src5 + g*args.nb51 + i3*args.nb53); // {d_state, ng, nt, ns}
device float * y = dst + (i1 + ir*(nr) + i3*(n_t*nh*nr)); // {dim, nh, nt, ns}
for (int i2 = 0; i2 < n_t; i2 += sgptg) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Pre-compute x_dt and dA for this batch of tokens
// Only first sgptg threads do the loads and expensive math
if (i0 < sgptg && i2 + i0 < n_t) {
// ns12 and ns21 are element strides (nb12/nb10, nb21/nb20)
device const float * x_t = x + i0 * args.ns12;
device const float * dt_t = dt + i0 * args.ns21;
const float dt0 = dt_t[0];
const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0;
shared_x_dt[i0] = x_t[0] * dtsp;
shared_dA[i0] = dtsp; // Store dtsp, compute exp(dtsp * A0) per-thread since A0 varies
}
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
const float x_dt = shared_x_dt[t];
const float dA = exp(shared_dA[t] * A0);
s = (s0 * dA) + (B[i0] * x_dt);
const float sumf = simd_sum(s * C[i0]);
if (tiisg == 0) {
shared_sums[t*NW + sgitg] = sumf;
}
// recurse
s0 = s;
B += args.ns42;
C += args.ns52;
}
// Advance pointers for next batch
x += sgptg * args.ns12;
dt += sgptg * args.ns21;
threadgroup_barrier(mem_flags::mem_threadgroup);
const float sumf = simd_sum(shared_sums[sgitg*NW + tiisg]);
if (tiisg == 0 && i2 + sgitg < n_t) {
y[sgitg*nh*nr] = sumf;
}
y += sgptg*nh*nr;
}
s_buff[i] = s;
}
-69
View File
@@ -1,69 +0,0 @@
#include "common.h"
template<uint32_t ttype>
bool _ggml_vec_tri_cmp(const int i, const int r);
template<>
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER */ 3>(const int i, const int r) {
return i < r;
}
template<>
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER_DIAG */ 2>(const int i, const int r) {
return i <= r;
}
template<>
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER */ 1>(const int i, const int r) {
return i > r;
}
template<>
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER_DIAG */ 0>(const int i, const int r) {
return i >= r;
}
template<typename T, int ttype>
kernel void kernel_tri(
constant ggml_metal_kargs_tri & args,
device const char * src0,
device const char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i3 = tgpig.z;
const int i2 = tgpig.y;
const int i1 = tgpig.x;
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
return;
}
device const T * src_row = (device const T *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
device T * dst_row = (device T *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
// Each thread is a single element of the row if ne00 < max threads per
// threadgroup, so this will loop once for each index that this thread is
// responsible for
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
// Use the comparison as a mask for branchless
dst_row[i0] = static_cast<T>(_ggml_vec_tri_cmp<ttype>(i0, i1)) * src_row[i0];
}
}
typedef decltype(kernel_tri<float, 0>) kernel_tri_t;
template [[host_name("kernel_tri_f32_0")]] kernel kernel_tri_t kernel_tri<float, 0>;
template [[host_name("kernel_tri_f32_1")]] kernel kernel_tri_t kernel_tri<float, 1>;
template [[host_name("kernel_tri_f32_2")]] kernel kernel_tri_t kernel_tri<float, 2>;
template [[host_name("kernel_tri_f32_3")]] kernel kernel_tri_t kernel_tri<float, 3>;
template [[host_name("kernel_tri_f16_0")]] kernel kernel_tri_t kernel_tri<half, 0>;
template [[host_name("kernel_tri_f16_1")]] kernel kernel_tri_t kernel_tri<half, 1>;
template [[host_name("kernel_tri_f16_2")]] kernel kernel_tri_t kernel_tri<half, 2>;
template [[host_name("kernel_tri_f16_3")]] kernel kernel_tri_t kernel_tri<half, 3>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_tri_bf16_0")]] kernel kernel_tri_t kernel_tri<bfloat, 0>;
template [[host_name("kernel_tri_bf16_1")]] kernel kernel_tri_t kernel_tri<bfloat, 1>;
template [[host_name("kernel_tri_bf16_2")]] kernel kernel_tri_t kernel_tri<bfloat, 2>;
template [[host_name("kernel_tri_bf16_3")]] kernel kernel_tri_t kernel_tri<bfloat, 3>;
#endif
-360
View File
@@ -1,360 +0,0 @@
#include "common.h"
constant short FC_unary_op [[function_constant(FC_UNARY + 0)]];
constant bool FC_unary_cnt[[function_constant(FC_UNARY + 1)]];
template <typename T0, typename T, typename TC>
kernel void kernel_unary_impl(
constant ggml_metal_kargs_unary & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
#define FC_OP FC_unary_op
#define FC_CNT FC_unary_cnt
device const T0 * src0_ptr;
device T * dst_ptr;
int i0;
if (FC_CNT) {
i0 = tgpig.x;
src0_ptr = (device const T0 *) (src0);
dst_ptr = (device T *) (dst);
} else {
const int i03 = tgpig.z;
const int i02 = tgpig.y;
const int k0 = tgpig.x/args.ne01;
const int i01 = tgpig.x - k0*args.ne01;
i0 = k0*ntg.x + tpitg.x;
src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 );
}
{
//threadgroup_barrier(mem_flags::mem_none);
if (!FC_CNT) {
if (i0 >= args.ne0) {
return;
}
}
const TC x = (TC) src0_ptr[i0];
if (FC_OP == OP_UNARY_NUM_SCALE) {
dst_ptr[i0] = (T) (args.scale * x + args.bias);
}
if (FC_OP == OP_UNARY_NUM_FILL) {
dst_ptr[i0] = (T) args.val;
}
if (FC_OP == OP_UNARY_NUM_CLAMP) {
dst_ptr[i0] = (T) clamp(x, args.min, args.max);
}
if (FC_OP == OP_UNARY_NUM_SQR) {
dst_ptr[i0] = (T) (x * x);
}
if (FC_OP == OP_UNARY_NUM_SQRT) {
dst_ptr[i0] = (T) sqrt(x);
}
if (FC_OP == OP_UNARY_NUM_SIN) {
dst_ptr[i0] = (T) sin(x);
}
if (FC_OP == OP_UNARY_NUM_COS) {
dst_ptr[i0] = (T) cos(x);
}
if (FC_OP == OP_UNARY_NUM_LOG) {
dst_ptr[i0] = (T) log(x);
}
if (FC_OP == OP_UNARY_NUM_LEAKY_RELU) {
dst_ptr[i0] = (T) (TC(x > 0)*x + TC(x <= 0)*(x * args.slope));
}
if (FC_OP == OP_UNARY_NUM_TANH) {
dst_ptr[i0] = (T) precise::tanh(x);
}
if (FC_OP == OP_UNARY_NUM_RELU) {
dst_ptr[i0] = (T) fmax(0, x);
}
if (FC_OP == OP_UNARY_NUM_SIGMOID) {
dst_ptr[i0] = (T) (1 / (1 + exp(-x)));
}
if (FC_OP == OP_UNARY_NUM_GELU) {
dst_ptr[i0] = (T) (0.5*x*(1 + precise::tanh(SQRT_2_OVER_PI*x*(1 + GELU_COEF_A*x*x))));
}
if (FC_OP == OP_UNARY_NUM_GELU_ERF) {
dst_ptr[i0] = (T) (0.5*x*(1 + erf_approx(SQRT_2_INV*x)));
}
if (FC_OP == OP_UNARY_NUM_GELU_QUICK) {
dst_ptr[i0] = (T) (x * (1/(1 + exp(GELU_QUICK_COEF*x))));
}
if (FC_OP == OP_UNARY_NUM_SILU) {
dst_ptr[i0] = (T) (x / (1 + exp(-x)));
}
if (FC_OP == OP_UNARY_NUM_ELU) {
dst_ptr[i0] = (T) elu_approx(x);
}
if (FC_OP == OP_UNARY_NUM_NEG) {
dst_ptr[i0] = (T) -x;
}
if (FC_OP == OP_UNARY_NUM_ABS) {
dst_ptr[i0] = (T) fabs(x);
}
if (FC_OP == OP_UNARY_NUM_SGN) {
dst_ptr[i0] = T(x > 0) - T(x < 0);
}
if (FC_OP == OP_UNARY_NUM_STEP) {
dst_ptr[i0] = T(x > 0);
}
if (FC_OP == OP_UNARY_NUM_HARDSWISH) {
dst_ptr[i0] = (T) (x * fmax(0, fmin(1, x/6 + 0.5)));
}
if (FC_OP == OP_UNARY_NUM_HARDSIGMOID) {
dst_ptr[i0] = (T) fmax(0, fmin(1, x/6 + 0.5));
}
if (FC_OP == OP_UNARY_NUM_EXP) {
dst_ptr[i0] = (T) exp(x);
}
if (FC_OP == OP_UNARY_NUM_SOFTPLUS) {
dst_ptr[i0] = (T) select(log(1 + exp(x)), x, x > 20);
}
if (FC_OP == OP_UNARY_NUM_EXPM1) {
// TODO: precise implementation
dst_ptr[i0] = (T) (exp(x) - 1);
}
if (FC_OP == OP_UNARY_NUM_FLOOR) {
dst_ptr[i0] = (T) floor(x);
}
if (FC_OP == OP_UNARY_NUM_CEIL) {
dst_ptr[i0] = (T) ceil(x);
}
if (FC_OP == OP_UNARY_NUM_ROUND) {
dst_ptr[i0] = (T) round(x);
}
if (FC_OP == OP_UNARY_NUM_TRUNC) {
dst_ptr[i0] = (T) trunc(x);
}
if (FC_OP == OP_UNARY_NUM_XIELU) {
const TC xi = x;
const TC gate = TC(xi > TC(0.0f));
const TC clamped = fmin(xi, TC(args.val));
const TC y_pos = TC(args.scale) * xi * xi + TC(args.bias) * xi;
const TC y_neg = (exp(clamped) - TC(1.0f) - xi) * TC(args.slope) + TC(args.bias) * xi;
dst_ptr[i0] = (T) (gate * y_pos + (TC(1.0f) - gate) * y_neg);
}
}
#undef FC_OP
#undef FC_CNT
}
typedef decltype(kernel_unary_impl<float, float, float>) kernel_unary_t;
template [[host_name("kernel_unary_f32_f32")]] kernel kernel_unary_t kernel_unary_impl<float, float, float>;
template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl<float4, float4, float4>;
template [[host_name("kernel_unary_f16_f16")]] kernel kernel_unary_t kernel_unary_impl<half, half, float>;
template [[host_name("kernel_unary_f16_f16_4")]] kernel kernel_unary_t kernel_unary_impl<half4, half4, float4>;
template<typename T>
kernel void kernel_reglu(
constant ggml_metal_kargs_glu & args,
device const char * src0,
device const char * src1,
device char * dst,
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint ntg[[threads_per_threadgroup]]) {
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
const float x0 = src0_row[i0];
const float x1 = src1_row[i0];
dst_row[i0] = (T)(x0*x1*(x0 > 0.0f));
}
}
typedef decltype(kernel_reglu<float>) kernel_reglu_t;
template [[host_name("kernel_reglu_f32")]] kernel kernel_reglu_t kernel_reglu<float>;
template [[host_name("kernel_reglu_f16")]] kernel kernel_reglu_t kernel_reglu<half>;
template<typename T>
kernel void kernel_geglu(
constant ggml_metal_kargs_glu & args,
device const char * src0,
device const char * src1,
device char * dst,
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint ntg[[threads_per_threadgroup]]) {
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
const float x0 = src0_row[i0];
const float x1 = src1_row[i0];
const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
dst_row[i0] = (T)(gelu*x1);
}
}
typedef decltype(kernel_geglu<float>) kernel_geglu_t;
template [[host_name("kernel_geglu_f32")]] kernel kernel_geglu_t kernel_geglu<float>;
template [[host_name("kernel_geglu_f16")]] kernel kernel_geglu_t kernel_geglu<half>;
template<typename T>
kernel void kernel_swiglu(
constant ggml_metal_kargs_glu & args,
device const char * src0,
device const char * src1,
device char * dst,
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint ntg[[threads_per_threadgroup]]) {
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
const float x0 = src0_row[i0];
const float x1 = src1_row[i0];
const float silu = x0 / (1.0f + exp(-x0));
dst_row[i0] = (T)(silu*x1);
}
}
typedef decltype(kernel_swiglu<float>) kernel_swiglu_t;
template [[host_name("kernel_swiglu_f32")]] kernel kernel_swiglu_t kernel_swiglu<float>;
template [[host_name("kernel_swiglu_f16")]] kernel kernel_swiglu_t kernel_swiglu<half>;
template<typename T>
kernel void kernel_swiglu_oai(
constant ggml_metal_kargs_glu & args,
device const char * src0,
device const char * src1,
device char * dst,
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint ntg[[threads_per_threadgroup]]) {
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
float x0 = src0_row[i0];
float x1 = src1_row[i0];
x0 = min(x0, args.limit);
x1 = max(min(x1, args.limit), -args.limit);
float out_glu = x0 / (1.0f + exp(-x0 * args.alpha));
out_glu = out_glu * (1.0f + x1);
dst_row[i0] = (T)out_glu;
}
}
typedef decltype(kernel_swiglu_oai<float>) kernel_swiglu_oai_t;
template [[host_name("kernel_swiglu_oai_f32")]] kernel kernel_swiglu_oai_t kernel_swiglu_oai<float>;
template [[host_name("kernel_swiglu_oai_f16")]] kernel kernel_swiglu_oai_t kernel_swiglu_oai<half>;
template<typename T>
kernel void kernel_geglu_erf(
constant ggml_metal_kargs_glu & args,
device const char * src0,
device const char * src1,
device char * dst,
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint ntg[[threads_per_threadgroup]]) {
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
const float x0 = src0_row[i0];
const float x1 = src1_row[i0];
const float gelu_erf = 0.5f*x0*(1.0f+erf_approx<float>(x0*SQRT_2_INV));
dst_row[i0] = (T)(gelu_erf*x1);
}
}
typedef decltype(kernel_geglu_erf<float>) kernel_geglu_erf_t;
template [[host_name("kernel_geglu_erf_f32")]] kernel kernel_geglu_erf_t kernel_geglu_erf<float>;
template [[host_name("kernel_geglu_erf_f16")]] kernel kernel_geglu_erf_t kernel_geglu_erf<half>;
template<typename T>
kernel void kernel_geglu_quick(
constant ggml_metal_kargs_glu & args,
device const char * src0,
device const char * src1,
device char * dst,
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint ntg[[threads_per_threadgroup]]) {
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
const float x0 = src0_row[i0];
const float x1 = src1_row[i0];
const float gelu_quick = x0*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x0)));
dst_row[i0] = (T)(gelu_quick*x1);
}
}
typedef decltype(kernel_geglu_quick<float>) kernel_geglu_quick_t;
template [[host_name("kernel_geglu_quick_f32")]] kernel kernel_geglu_quick_t kernel_geglu_quick<float>;
template [[host_name("kernel_geglu_quick_f16")]] kernel kernel_geglu_quick_t kernel_geglu_quick<half>;
-179
View File
@@ -1,179 +0,0 @@
#include "common.h"
constant bool FC_upscale_aa [[function_constant(FC_UPSCALE + 0)]];
kernel void kernel_upscale_nearest_f32(
constant ggml_metal_kargs_upscale & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i3 = tgpig.z;
const int64_t i2 = tgpig.y;
const int64_t i1 = tgpig.x;
const int64_t i03 = i3/args.sf3;
const int64_t i02 = i2/args.sf2;
const int64_t i01 = i1/args.sf1;
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const int64_t i00 = i0/args.sf0;
device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
dst_ptr[0] = src0_ptr[0];
}
}
static inline float bilinear_tri(float x) {
return MAX(0.0f, 1.0f - fabs(x));
}
kernel void kernel_upscale_bilinear_f32(
constant ggml_metal_kargs_upscale & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i3 = tgpig.z;
const int64_t i2 = tgpig.y;
const int64_t i1 = tgpig.x;
const int64_t i03 = i3 / args.sf3;
const int64_t i02 = i2 / args.sf2;
const float f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs;
const int64_t i01 = MAX(0, MIN(args.ne01 - 1, (int64_t)floor(f01)));
const int64_t i01p = MAX(0, MIN(args.ne01 - 1, i01 + 1));
const float fd1 = MAX(0.0f, MIN(1.0f, f01 - (float)i01));
src0 += i03*args.nb03 + i02*args.nb02;
device float * dst_ptr = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
if (FC_upscale_aa) {
const float support0 = MAX(1.0f, 1.0f / args.sf0);
const float invscale0 = 1.0f / support0;
const float support1 = MAX(1.0f, 1.0f / args.sf1);
const float invscale1 = 1.0f / support1;
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
int64_t x_min = MAX((int64_t)0, (int64_t)floor(f00 - support0 + args.poffs));
int64_t x_max = MIN(args.ne00, (int64_t)ceil (f00 + support0 + args.poffs));
int64_t y_min = MAX((int64_t)0, (int64_t)floor(f01 - support1 + args.poffs));
int64_t y_max = MIN(args.ne01, (int64_t)ceil (f01 + support1 + args.poffs));
float sum = 0.0f;
float wsum = 0.0f;
for (int64_t sy = y_min; sy < y_max; ++sy) {
const float wy = MAX(0.0f, 1.0f - fabs((float)sy - f01) * invscale1);
for (int64_t sx = x_min; sx < x_max; ++sx) {
const float wx = MAX(0.0f, 1.0f - fabs((float)sx - f00) * invscale0);
const float w = wx * wy;
device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00);
sum += (*src_ptr) * w;
wsum += w;
}
}
const float v = (wsum > 0.0f) ? (sum / wsum) : 0.0f;
dst_ptr[i0] = v;
}
} else {
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
const int64_t i00 = MAX(0, MIN(args.ne00 - 1, (int64_t)floor(f00)));
const int64_t i00p = MAX(0, MIN(args.ne00 - 1, i00 + 1));
const float fd0 = MAX(0.0f, MIN(1.0f, f00 - (float)i00));
device const float * src00 = (device const float *)(src0 + i01*args.nb01 + i00*args.nb00);
device const float * src10 = (device const float *)(src0 + i01*args.nb01 + i00p*args.nb00);
device const float * src01 = (device const float *)(src0 + i01p*args.nb01 + i00*args.nb00);
device const float * src11 = (device const float *)(src0 + i01p*args.nb01 + i00p*args.nb00);
const float v =
(*src00) * (1.0f - fd0) * (1.0f - fd1) +
(*src10) * fd0 * (1.0f - fd1) +
(*src01) * (1.0f - fd0) * fd1 +
(*src11) * fd0 * fd1;
dst_ptr[i0] = v;
}
}
}
static inline float bicubic_weight1(float x) {
const float a = -0.75f;
return ((a + 2) * x - (a + 3)) * x * x + 1;
}
static inline float bicubic_weight2(float x) {
const float a = -0.75f;
return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a;
}
kernel void kernel_upscale_bicubic_f32(
constant ggml_metal_kargs_upscale & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i3 = tgpig.z;
const int64_t i2 = tgpig.y;
const int64_t i1 = tgpig.x;
const int64_t i03 = i3 / args.sf3;
const int64_t i02 = i2 / args.sf2;
const float f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs;
const int64_t i01 = (int64_t)floor(f01);
const float fd1 = f01 - (float)i01;
const float w_y0 = bicubic_weight2(fd1 + 1.0f);
const float w_y1 = bicubic_weight1(fd1);
const float w_y2 = bicubic_weight1(1.0f - fd1);
const float w_y3 = bicubic_weight2(2.0f - fd1);
const device const char * src_slice = src0 + i03 * args.nb03 + i02 * args.nb02;
device float * dst_ptr = (device float *)(dst + i3 * args.nb3 + i2 * args.nb2 + i1 * args.nb1);
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
const int64_t i00 = (int64_t)floor(f00);
const float fd0 = f00 - (float)i00;
const float w_x0 = bicubic_weight2(fd0 + 1.0f);
const float w_x1 = bicubic_weight1(fd0);
const float w_x2 = bicubic_weight1(1.0f - fd0);
const float w_x3 = bicubic_weight2(2.0f - fd0);
float sum = 0.0f;
for (int dy = -1; dy <= 2; ++dy) {
const int64_t iy = MAX(0, MIN(args.ne01 - 1, i01 + dy));
const float wy = (dy == -1) ? w_y0 : (dy == 0) ? w_y1 : (dy == 1) ? w_y2 : w_y3;
for (int dx = -1; dx <= 2; ++dx) {
const int64_t ix = MAX(0, MIN(args.ne00 - 1, i00 + dx));
const float wx = (dx == -1) ? w_x0 : (dx == 0) ? w_x1 : (dx == 1) ? w_x2 : w_x3;
device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00);
sum += (*src_ptr) * wx * wy;
}
}
dst_ptr[i0] = sum;
}
}
-179
View File
@@ -1,179 +0,0 @@
#include "common.h"
kernel void kernel_rwkv_wkv6_f32(
device const float * k,
device const float * v,
device const float * r,
device const float * tf,
device const float * td,
device const float * state_in,
device float * dst,
constant uint & B,
constant uint & T,
constant uint & C,
constant uint & H,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const uint head_size = 64; // TODO: support head_size = 128
const uint batch_id = tgpig.x / H;
const uint head_id = tgpig.x % H;
const uint tid = tpitg.x;
if (batch_id >= B || head_id >= H) {
return;
}
const uint state_size = C * head_size;
const uint n_seq_tokens = T / B;
threadgroup float _k[head_size];
threadgroup float _r[head_size];
threadgroup float _tf[head_size];
threadgroup float _td[head_size];
float state[head_size];
for (uint i = 0; i < head_size; i++) {
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+ i * head_size + tid];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
_tf[tid] = tf[head_id * head_size + tid];
threadgroup_barrier(mem_flags::mem_threadgroup);
const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
for (uint t = start_t; t < end_t; t += C) {
threadgroup_barrier(mem_flags::mem_threadgroup);
_k[tid] = k[t];
_r[tid] = r[t];
_td[tid] = td[t];
threadgroup_barrier(mem_flags::mem_threadgroup);
const float v_val = v[t];
float y = 0.0;
for (uint j = 0; j < head_size; j += 4) {
float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
float4 tf_vec = float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
float4 td_vec = float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
float4 kv = k_vec * v_val;
float4 temp = tf_vec * kv + s_vec;
y += dot(r_vec, temp);
s_vec = s_vec * td_vec + kv;
state[j] = s_vec[0];
state[j+1] = s_vec[1];
state[j+2] = s_vec[2];
state[j+3] = s_vec[3];
}
dst[t] = y;
}
for (uint i = 0; i < head_size; i++) {
dst[T * C + batch_id * state_size + head_id * head_size * head_size
+ i * head_size + tid] = state[i];
}
}
kernel void kernel_rwkv_wkv7_f32(
device const float * r,
device const float * w,
device const float * k,
device const float * v,
device const float * a,
device const float * b,
device const float * state_in,
device float * dst,
constant uint & B,
constant uint & T,
constant uint & C,
constant uint & H,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const uint head_size = 64; // TODO: support head_size = 128
const uint batch_id = tgpig.x / H;
const uint head_id = tgpig.x % H;
const uint tid = tpitg.x;
if (batch_id >= B || head_id >= H) {
return;
}
const uint state_size = C * head_size;
const uint n_seq_tokens = T / B;
threadgroup float _r[head_size];
threadgroup float _w[head_size];
threadgroup float _k[head_size];
threadgroup float _a[head_size];
threadgroup float _b[head_size];
float state[head_size];
for (uint i = 0; i < head_size; i++) {
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+ tid * head_size + i];
}
const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
for (uint t = start_t; t < end_t; t += C) {
threadgroup_barrier(mem_flags::mem_threadgroup);
_r[tid] = r[t];
_w[tid] = w[t];
_k[tid] = k[t];
_a[tid] = a[t];
_b[tid] = b[t];
threadgroup_barrier(mem_flags::mem_threadgroup);
const float v_val = v[t];
float y = 0.0, sa = 0.0;
float4 sa_vec(0.0);
for (uint j = 0; j < head_size; j += 4) {
float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
sa_vec += a_vec * s_vec;
}
sa = sa_vec[0] + sa_vec[1] + sa_vec[2] + sa_vec[3];
for (uint j = 0; j < head_size; j += 4) {
float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
float4 w_vec = float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
float4 b_vec = float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
float4 kv = k_vec * v_val;
s_vec = s_vec * w_vec + kv + sa * b_vec;
y += dot(s_vec, r_vec);
state[j] = s_vec[0];
state[j+1] = s_vec[1];
state[j+2] = s_vec[2];
state[j+3] = s_vec[3];
}
dst[t] = y;
}
for (uint i = 0; i < head_size; i++) {
dst[T * C + batch_id * state_size + head_id * head_size * head_size
+ tid * head_size + i] = state[i];
}
}
+13 -1
View File
@@ -60,7 +60,19 @@ if (Vulkan_FOUND)
message(STATUS "Vulkan found")
ggml_add_backend_library(ggml-vulkan
ggml-vulkan.cpp
ggml-vulkan-instance.cpp
ggml-vulkan-shaders.cpp
ggml-vulkan-buffers.cpp
ggml-vulkan-pipelines.cpp
ggml-vulkan-matmul.cpp
ggml-vulkan-flash-attn.cpp
ggml-vulkan-operators.cpp
ggml-vulkan-graph.cpp
ggml-vulkan-backend.cpp
ggml-vulkan-debug.cpp
ggml-vulkan-types.h
ggml-vulkan-push-constants.h
ggml-vulkan-common.h
../../include/ggml-vulkan.h
)
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,783 @@
#include "ggml-vulkan-common.h"
ggml_backend_buffer_type_i ggml_backend_vk_buffer_type_interface = {
/* .get_name = */ ggml_backend_vk_buffer_type_name,
/* .alloc_buffer = */ ggml_backend_vk_buffer_type_alloc_buffer,
/* .get_alignment = */ ggml_backend_vk_buffer_type_get_alignment,
/* .get_max_size = */ ggml_backend_vk_buffer_type_get_max_size,
/* .get_alloc_size = */ ggml_backend_vk_buffer_type_get_alloc_size,
/* .is_host = */ NULL,
};
static std::vector<uint32_t> ggml_vk_find_memory_properties(const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags) {
std::vector<uint32_t> indices;
for (uint32_t i = 0; i < mem_props->memoryTypeCount; ++i) {
vk::MemoryType memory_type = mem_props->memoryTypes[i];
if ((mem_req->memoryTypeBits & ((uint64_t)1 << i)) &&
(flags & memory_type.propertyFlags) == flags &&
mem_props->memoryHeaps[memory_type.heapIndex].size >= mem_req->size) {
indices.push_back(i);
}
}
return indices;
}
static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list<vk::MemoryPropertyFlags> & req_flags_list,
void *import_ptr = nullptr) {
VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags_list.begin()[0]) << ", " << to_string(req_flags_list.begin()[req_flags_list.size()-1]) << ")");
if (size > device->max_buffer_size) {
throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device buffer size limit");
}
vk_buffer buf = std::make_shared<vk_buffer_struct>();
if (size == 0) {
buf->size = 0;
return buf;
}
vk::BufferUsageFlags usage_flags = vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst;
vk::MemoryAllocateFlags mem_flags {};
if (device->buffer_device_address) {
usage_flags |= vk::BufferUsageFlagBits::eShaderDeviceAddress;
mem_flags |= vk::MemoryAllocateFlagBits::eDeviceAddress;
}
vk::BufferCreateInfo buffer_create_info{
vk::BufferCreateFlags(),
size,
usage_flags,
vk::SharingMode::eExclusive,
0,
nullptr,
};
vk::ExternalMemoryBufferCreateInfo external_memory_bci;
if (import_ptr) {
external_memory_bci.handleTypes = vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT;
buffer_create_info.setPNext(&external_memory_bci);
}
buf->buffer = device->device.createBuffer(buffer_create_info);
vk::MemoryRequirements mem_req = device->device.getBufferMemoryRequirements(buf->buffer);
vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
const vk::MemoryPriorityAllocateInfoEXT mem_priority_info { 1.0f };
vk::MemoryAllocateFlagsInfo mem_flags_info { mem_flags };
if (device->memory_priority) {
mem_flags_info.setPNext(&mem_priority_info);
}
if (import_ptr) {
vk::MemoryHostPointerPropertiesEXT host_pointer_props;
try {
host_pointer_props = device->device.getMemoryHostPointerPropertiesEXT(vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT, import_ptr);
} catch (vk::SystemError& e) {
GGML_LOG_WARN("ggml_vulkan: Failed getMemoryHostPointerPropertiesEXT (%s)\n", e.what());
device->device.destroyBuffer(buf->buffer);
return {};
}
vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
uint32_t memory_type_idx;
vk::MemoryPropertyFlags property_flags = *req_flags_list.begin();
for (memory_type_idx = 0; memory_type_idx < 32; ++memory_type_idx) {
if (!(host_pointer_props.memoryTypeBits & (1u << memory_type_idx))) {
continue;
}
if (!(mem_req.memoryTypeBits & (1u << memory_type_idx))) {
continue;
}
vk::MemoryType memory_type = mem_props.memoryTypes[memory_type_idx];
// check for visible+coherent+cached. Other flags (e.g. devicelocal) are allowed
if ((memory_type.propertyFlags & property_flags) == property_flags) {
property_flags = memory_type.propertyFlags;
break;
}
}
if (memory_type_idx == 32) {
GGML_LOG_WARN("ggml_vulkan: Memory type for host allocation not found\n");
device->device.destroyBuffer(buf->buffer);
return {};
}
buf->memory_property_flags = mem_props.memoryTypes[memory_type_idx].propertyFlags;
try {
vk::ImportMemoryHostPointerInfoEXT import_info;
import_info.handleType = vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT;
import_info.pHostPointer = import_ptr;
import_info.setPNext(&mem_flags_info);
buf->device_memory = device->device.allocateMemory({ size, memory_type_idx, &import_info });
} catch (const vk::SystemError& e) {
}
} else {
for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) {
const auto & req_flags = *it;
const std::vector<uint32_t> memory_type_indices = ggml_vk_find_memory_properties(&mem_props, &mem_req, req_flags);
if (memory_type_indices.empty()) {
continue;
}
bool done = false;
for (auto mtype_it = memory_type_indices.begin(); mtype_it != memory_type_indices.end(); mtype_it++) {
try {
buf->device_memory = device->device.allocateMemory({ mem_req.size, *mtype_it, &mem_flags_info });
buf->memory_property_flags = mem_props.memoryTypes[*mtype_it].propertyFlags;
done = true;
break;
} catch (const vk::SystemError& e) {
// loop and retry
// during last attempt throw the exception
if (it + 1 == req_flags_list.end() && mtype_it + 1 == memory_type_indices.end()) {
device->device.destroyBuffer(buf->buffer);
throw e;
}
}
}
if (done) {
break;
}
}
}
if (!buf->device_memory) {
device->device.destroyBuffer(buf->buffer);
throw vk::OutOfDeviceMemoryError("No suitable memory type found");
}
buf->ptr = nullptr;
if (import_ptr) {
buf->ptr = import_ptr;
} else {
if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE);
}
}
device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0);
buf->device = device;
buf->size = size;
if (device->buffer_device_address) {
const vk::BufferDeviceAddressInfo addressInfo(buf->buffer);
buf->bda_addr = device->device.getBufferAddress(addressInfo);
}
device->memory_logger->log_allocation(buf, size);
return buf;
}
vk_buffer ggml_vk_create_buffer_check(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags) {
try {
return ggml_vk_create_buffer(device, size, {req_flags, fallback_flags});
} catch (const vk::SystemError& e) {
std::cerr << "ggml_vulkan: Memory allocation of size " << size << " failed." << std::endl;
std::cerr << "ggml_vulkan: " << e.what() << std::endl;
throw e;
}
}
vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) {
vk_buffer buf;
try {
if (device->prefer_host_memory) {
buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent,
vk::MemoryPropertyFlagBits::eDeviceLocal});
} else if (device->uma) {
// On UMA, prefer host-visible memory so direct tensor borrowing works.
// If unavailable, fall back to device-local memory.
buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent,
vk::MemoryPropertyFlagBits::eDeviceLocal,
vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent});
} else if (device->disable_host_visible_vidmem) {
if (device->allow_sysmem_fallback) {
buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal,
vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent});
} else {
buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal});
}
} else {
// use rebar if available, otherwise fallback to device only visible memory
if (device->allow_sysmem_fallback) {
buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent,
vk::MemoryPropertyFlagBits::eDeviceLocal,
vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent});
} else {
buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent,
vk::MemoryPropertyFlagBits::eDeviceLocal});
}
}
} catch (const vk::SystemError& e) {
std::cerr << "ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl;
std::cerr << "ggml_vulkan: " << e.what() << std::endl;
throw e;
}
return buf;
}
void ggml_vk_destroy_buffer(vk_buffer& buf) {
if (buf == nullptr) {
return;
}
if (buf->device != nullptr) {
buf->device->memory_logger->log_deallocation(buf);
}
buf.reset();
}
void * ggml_vk_host_malloc(vk_device& device, size_t size) {
VK_LOG_MEMORY("ggml_vk_host_malloc(" << size << ")");
vk_buffer buf = ggml_vk_create_buffer(device, size,
{vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent});
if(!(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible)) {
fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory\n",
size/1024.0/1024.0);
device->device.freeMemory(buf->device_memory);
device->device.destroyBuffer(buf->buffer);
return nullptr;
}
std::lock_guard<std::shared_mutex> guard(device->pinned_memory_mutex);
device->pinned_memory.push_back(std::make_tuple(buf->ptr, size, buf));
return buf->ptr;
}
void ggml_vk_host_free(vk_device& device, void* ptr) {
if (ptr == nullptr) {
return;
}
VK_LOG_MEMORY("ggml_vk_host_free(" << ptr << ")");
std::lock_guard<std::shared_mutex> guard(device->pinned_memory_mutex);
vk_buffer buf;
size_t index;
for (size_t i = 0; i < device->pinned_memory.size(); i++) {
const uint8_t* addr = (const uint8_t*) std::get<0>(device->pinned_memory[i]);
const uint8_t* endr = addr + std::get<1>(device->pinned_memory[i]);
if (ptr >= addr && ptr < endr) {
buf = std::get<2>(device->pinned_memory[i]);
index = i;
break;
}
}
if (buf == nullptr) {
fprintf(stderr, "WARNING: failed to free pinned memory: memory not in map\n");
return;
}
ggml_vk_destroy_buffer(buf);
device->pinned_memory.erase(device->pinned_memory.begin() + index);
}
void ggml_vk_host_get(const vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) {
std::shared_lock<std::shared_mutex> guard(device->pinned_memory_mutex);
buf = nullptr;
buf_offset = 0;
for (size_t i = 0; i < device->pinned_memory.size(); i++) {
const uint8_t* addr = (const uint8_t*) std::get<0>(device->pinned_memory[i]);
const uint8_t* endr = addr + std::get<1>(device->pinned_memory[i]);
if (ptr >= addr && ptr < endr) {
buf = std::get<2>(device->pinned_memory[i]);
buf_offset = ((const uint8_t *)ptr) - addr;
break;
}
}
}
void ggml_vk_ensure_sync_staging_buffer(vk_device& device, size_t size) {
if (device->sync_staging == nullptr || device->sync_staging->size < size) {
VK_LOG_MEMORY("ggml_vk_ensure_sync_staging_buffer(" << size << ")");
ggml_vk_destroy_buffer(device->sync_staging);
device->sync_staging = ggml_vk_create_buffer_check(device, size,
vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
}
}
void ggml_vk_ensure_sync_staging_buffer(ggml_backend_vk_context * ctx, size_t size) {
if (ctx->sync_staging == nullptr || ctx->sync_staging->size < size) {
VK_LOG_MEMORY("ggml_vk_ensure_sync_staging_buffer(" << size << ")");
ggml_vk_destroy_buffer(ctx->sync_staging);
ctx->sync_staging = ggml_vk_create_buffer_check(ctx->device, size,
vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
}
}
static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_context& subctx, vk_buffer& dst, size_t offset, const ggml_tensor * tensor, bool sync_staging = false) {
VK_LOG_DEBUG("ggml_vk_buffer_write_nc_async(" << tensor << ")");
GGML_ASSERT(!ggml_is_contiguous(tensor));
// Buffer is already mapped
if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
std::cerr << "ggml_vulkan: buffer_write_nc_async dst buffer is host_visible. Use synchronous write." << std::endl;
GGML_ABORT("fatal error");
}
// Check if src is pinned memory
vk_buffer buf = nullptr;
size_t buf_offset = 0;
ggml_vk_host_get(ctx->device, tensor->data, buf, buf_offset);
const uint64_t ne0 = tensor->ne[0];
const uint64_t ne1 = tensor->ne[1];
const uint64_t ne2 = tensor->ne[2];
const uint64_t ne3 = tensor->ne[3];
const uint64_t nb0 = tensor->nb[0];
const uint64_t nb1 = tensor->nb[1];
const uint64_t nb2 = tensor->nb[2];
const uint64_t nb3 = tensor->nb[3];
const ggml_type type = tensor->type;
const uint64_t ts = ggml_type_size(type);
const uint64_t bs = ggml_blck_size(type);
const uint64_t dstnb0 = ts;
const uint64_t dstnb1 = dstnb0*(ne0/bs);
const uint64_t dstnb2 = dstnb1*ne1;
const uint64_t dstnb3 = dstnb2*ne2;
const uint64_t ne = ggml_nelements(tensor);
if (buf != nullptr) {
// Memory is pinned, use as staging buffer
std::vector<vk::BufferCopy> slices;
for (uint64_t i3 = 0; i3 < ne3; i3++) {
for (uint64_t i2 = 0; i2 < ne2; i2++) {
// Find longest contiguous slice
if (ne1*nb1 == dstnb2) {
slices.push_back({ buf_offset + i3*nb3 + i2*nb2, offset + i3*dstnb3 + i2*dstnb2, dstnb2 });
} else {
for (uint64_t i1 = 0; i1 < ne1; i1++) {
if (ne0*nb0/bs == dstnb1) {
slices.push_back({ buf_offset + i3*nb3 + i2*nb2 + i1*nb1, offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, dstnb1 });
} else {
const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1;
const uint64_t d_off = offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1;
for (uint64_t i0 = 0; i0 < ne0; i0++) {
slices.push_back({ s_off + i0*nb0, d_off + i0*dstnb0, dstnb0 });
}
}
}
}
}
}
ggml_vk_sync_buffers(ctx, subctx);
subctx->s->buffer->buf.copyBuffer(buf->buffer, dst->buffer, slices);
return;
}
if (!sync_staging) {
GGML_ABORT("Asynchronous write to non-pinned memory not supported");
}
// Staging buffer required
vk_buffer& staging = ctx->device->sync_staging;
const uint64_t copy_size = ts*ne/bs;
ggml_vk_ensure_sync_staging_buffer(ctx->device, copy_size);
VkBufferCopy buf_copy{ 0, offset, copy_size };
ggml_vk_sync_buffers(ctx, subctx);
vkCmdCopyBuffer(subctx->s->buffer->buf, (VkBuffer)staging->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);
for (uint64_t i3 = 0; i3 < ne3; i3++) {
for (uint64_t i2 = 0; i2 < ne2; i2++) {
// Find longest contiguous slice
if (ne1*nb1 == dstnb2) {
deferred_memcpy((uint8_t *)staging->ptr + i3*dstnb3 + i2*dstnb2, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2, dstnb2, &subctx->in_memcpys);
} else {
for (uint64_t i1 = 0; i1 < ne1; i1++) {
if (ne0*nb0/bs == dstnb1) {
deferred_memcpy((uint8_t *)staging->ptr + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2 + i1*nb1, dstnb1, &subctx->in_memcpys);
} else {
const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1;
const uint64_t d_off = i3*dstnb3 + i2*dstnb2 + i1*dstnb1;
for (uint64_t i0 = 0; i0 < ne0; i0++) {
deferred_memcpy((uint8_t *)staging->ptr + d_off + i0*dstnb0, (const uint8_t *) tensor->data + s_off + i0*nb0, dstnb0, &subctx->in_memcpys);
}
}
}
}
}
}
}
bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t dpitch, size_t width, size_t height, bool sync_staging) {
VK_LOG_DEBUG("ggml_vk_buffer_write_2d_async(" << width << ", " << height << ")");
// Check if src is pinned memory
vk_buffer buf = nullptr;
size_t buf_offset = 0;
ggml_vk_host_get(dst->device, src, buf, buf_offset);
if (buf != nullptr) {
// Memory is pinned, use as staging buffer
std::vector<vk::BufferCopy> slices(1);
if (width == spitch && width == dpitch) {
// Only do single write if stride is equal
slices[0].srcOffset = buf_offset;
slices[0].dstOffset = offset;
slices[0].size = width * height;
} else {
slices.resize(height);
for (size_t i = 0; i < height; i++) {
slices[i].srcOffset = buf_offset + i * spitch;
slices[i].dstOffset = offset + i * dpitch;
slices[i].size = width;
}
}
ggml_vk_sync_buffers(nullptr, subctx);
subctx->s->buffer->buf.copyBuffer(buf->buffer, dst->buffer, slices);
return true;
}
VK_LOG_DEBUG("STAGING");
if (!sync_staging) {
// copy was not handled caller needs to fall back
return false;
}
// Staging buffer required
const size_t staging_size = width * height;
ggml_vk_ensure_sync_staging_buffer(dst->device, staging_size);
vk_buffer& staging_buffer = dst->device->sync_staging;
std::vector<vk::BufferCopy> slices(1);
if (width == dpitch) {
slices[0].srcOffset = 0;
slices[0].dstOffset = offset;
slices[0].size = staging_size;
} else {
slices.resize(height);
for (size_t i = 0; i < height; i++) {
slices[i].srcOffset = i * width;
slices[i].dstOffset = offset + i * dpitch;
slices[i].size = width;
}
}
ggml_vk_sync_buffers(nullptr, subctx);
subctx->s->buffer->buf.copyBuffer((VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, slices);
if (width == spitch) {
deferred_memcpy((uint8_t *)staging_buffer->ptr, src, staging_size, &subctx->in_memcpys);
} else {
for (size_t i = 0; i < height; i++) {
deferred_memcpy((uint8_t *)staging_buffer->ptr + i * width, (const uint8_t *) src + i * spitch, width, &subctx->in_memcpys);
}
}
return true;
}
bool ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging) {
VK_LOG_DEBUG("ggml_vk_buffer_write_async(" << size << ")");
return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, size, 1, sync_staging);
}
void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t dpitch, size_t width, size_t height) {
VK_LOG_DEBUG("ggml_vk_buffer_write_2d(" << width << ", " << height << ")");
// Buffer is already mapped
if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
GGML_ASSERT(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent);
if (width == spitch && width == dpitch) {
memcpy((uint8_t *)dst->ptr + offset, src, width * height);
} else {
for (size_t i = 0; i < height; i++) {
memcpy((uint8_t *)dst->ptr + offset + i * dpitch, (const uint8_t *) src + i * spitch, width);
}
}
} else {
std::lock_guard<std::recursive_mutex> guard(dst->device->mutex);
vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
ggml_vk_ctx_begin(dst->device, subctx);
bool ret = ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, dpitch, width, height, true);
GGML_ASSERT(ret);
ggml_vk_ctx_end(subctx);
for (auto& cpy : subctx->in_memcpys) {
memcpy(cpy.dst, cpy.src, cpy.n);
}
for (auto& mset : subctx->memsets) {
memset(mset.dst, mset.val, mset.n);
}
ggml_vk_submit(subctx, dst->device->fence);
VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences");
dst->device->device.resetFences({ dst->device->fence });
ggml_vk_queue_command_pools_cleanup(dst->device);
}
}
void ggml_vk_buffer_write(vk_buffer& dst, size_t offset, const void * src, size_t size) {
VK_LOG_DEBUG("ggml_vk_buffer_write(" << size << ")");
ggml_vk_buffer_write_2d(dst, offset, src, size, size, size, 1);
}
bool ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height, bool sync_staging) {
VK_LOG_DEBUG("ggml_vk_buffer_read_2d_async(offset=" << offset << ", width=" << width << ", height=" << height << ")");
GGML_ASSERT(width > 0);
GGML_ASSERT(height > 0);
GGML_ASSERT(src != nullptr);
// TODO: staging_offset is not used
// Check if dst is pinned memory
vk_buffer buf = nullptr;
size_t buf_offset = 0;
ggml_vk_host_get(src->device, dst, buf, buf_offset);
std::vector<vk::BufferCopy> slices(1);
if (width == spitch && width == dpitch) {
// Only do single write if stride is equal
slices[0].srcOffset = offset;
slices[0].dstOffset = buf_offset;
slices[0].size = width * height;
} else {
slices.resize(height);
for (size_t i = 0; i < height; i++) {
slices[i].srcOffset = offset + i * spitch;
slices[i].dstOffset = buf_offset + i * dpitch;
slices[i].size = width;
}
}
if (buf != nullptr) {
// Memory is pinned, use as staging buffer
ggml_vk_sync_buffers(nullptr, subctx);
subctx->s->buffer->buf.copyBuffer(src->buffer, buf->buffer, slices);
return true;
}
VK_LOG_DEBUG("STAGING");
if (!sync_staging) {
// copy was not handled caller needs to fall back
return false;
}
// Fall back to staging buffer
const size_t staging_size = width * height;
ggml_vk_ensure_sync_staging_buffer(src->device, staging_size);
vk_buffer& staging_buffer = src->device->sync_staging;
std::vector<vk::BufferCopy> staging_slices(1);
if (width == spitch) {
staging_slices[0].srcOffset = offset;
staging_slices[0].dstOffset = 0;
staging_slices[0].size = staging_size;
} else {
staging_slices.resize(height);
for (size_t i = 0; i < height; i++) {
staging_slices[i].srcOffset = offset + i * spitch;
staging_slices[i].dstOffset = i * width;
staging_slices[i].size = width;
}
}
ggml_vk_sync_buffers(nullptr, subctx);
subctx->s->buffer->buf.copyBuffer(src->buffer, staging_buffer->buffer, staging_slices);
if (width == dpitch) {
deferred_memcpy(dst, staging_buffer->ptr, staging_size, &subctx->out_memcpys);
} else {
for (size_t i = 0; i < height; i++) {
deferred_memcpy((uint8_t *) dst + i * dpitch, (const uint8_t *) staging_buffer->ptr + i * width, width, &subctx->out_memcpys);
}
}
return true;
}
static bool ggml_vk_buffer_read_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t size, bool sync_staging = false) {
return ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, size, size, size, 1, sync_staging);
}
void ggml_vk_buffer_read_2d(vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height) {
VK_LOG_DEBUG("ggml_vk_buffer_read_2d(" << src->buffer << ", " << offset << ", " << width << ", " << height << ")");
// If the device is not an UMA device the memory is host-accessible through rebar. While writing
// through PCIe is sufficient fast reading back data from PCIe is slower than going through
// the HW device to host copy path.
if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && src->device->uma) {
GGML_ASSERT(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent);
std::lock_guard<std::recursive_mutex> guard(src->device->mutex);
vk_context subctx = ggml_vk_create_temporary_context(src->device->compute_queue.cmd_pool);
ggml_vk_ctx_begin(src->device, subctx);
subctx->s->buffer->buf.pipelineBarrier(
vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer,
vk::PipelineStageFlagBits::eHost,
{},
{ { vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferWrite,
vk::AccessFlagBits::eHostRead } },
{}, {});
ggml_vk_ctx_end(subctx);
ggml_vk_submit(subctx, src->device->fence);
VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX),
"vk_buffer_read_2d uma waitForFences");
src->device->device.resetFences({ src->device->fence });
ggml_vk_queue_command_pools_cleanup(src->device);
if (width == spitch && width == dpitch) {
memcpy(dst, (const uint8_t *) src->ptr + offset, width * height);
} else {
for (size_t i = 0; i < height; i++) {
memcpy((uint8_t *) dst + i * dpitch, (const uint8_t *) src->ptr + offset + i * spitch, width);
}
}
} else {
std::lock_guard<std::recursive_mutex> guard(src->device->mutex);
vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool);
ggml_vk_ctx_begin(src->device, subctx);
bool ret = ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, spitch, dpitch, width, height, true);
GGML_ASSERT(ret);
ggml_vk_ctx_end(subctx);
ggml_vk_submit(subctx, src->device->fence);
VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_read_2d waitForFences");
src->device->device.resetFences({ src->device->fence });
ggml_vk_queue_command_pools_cleanup(src->device);
for (auto& cpy : subctx->out_memcpys) {
memcpy(cpy.dst, cpy.src, cpy.n);
}
}
}
void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) {
VK_LOG_DEBUG("ggml_vk_buffer_read(" << src->buffer << ", " << offset << ", " << size << ")");
ggml_vk_buffer_read_2d(src, offset, dst, size, size, size, 1);
}
void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
VK_LOG_DEBUG("ggml_vk_buffer_copy_async(" << size << ")");
// Make sure both buffers are on same device
GGML_ASSERT(src->device == dst->device);
VkBufferCopy bc{ src_offset, dst_offset, size };
vkCmdCopyBuffer(ctx->s->buffer->buf, (VkBuffer)src->buffer, (VkBuffer)dst->buffer, 1, &bc);
}
void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
if (src->device == dst->device) {
std::lock_guard<std::recursive_mutex> guard(src->device->mutex);
VK_LOG_DEBUG("ggml_vk_buffer_copy(SINGLE_DEVICE, " << size << ")");
// Copy within the device
vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool);
ggml_vk_ctx_begin(src->device, subctx);
ggml_vk_buffer_copy_async(subctx, dst, dst_offset, src, src_offset, size);
ggml_vk_ctx_end(subctx);
ggml_vk_submit(subctx, src->device->fence);
VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_copy waitForFences");
src->device->device.resetFences({ src->device->fence });
ggml_vk_queue_command_pools_cleanup(src->device);
} else {
VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")");
// Copy device to device
ggml_vk_ensure_sync_staging_buffer(src->device, size);
// Copy to src staging buffer
ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size);
// Copy to dst buffer
ggml_vk_buffer_write(dst, dst_offset, src->device->sync_staging->ptr, size);
}
}
void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
VK_LOG_DEBUG("ggml_vk_buffer_memset_async(" << offset << ", " << c << ", " << size << ")");
if (dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible &&
dst->device->uma) {
deferred_memset((uint8_t*)dst->ptr + offset, c, size, &ctx->memsets);
return;
}
// Fall back to GPU fillBuffer for non-UMA or non-host-visible buffers
ctx->s->buffer->buf.fillBuffer(dst->buffer, offset, size, c);
}
void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")");
if (dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible &&
dst->device->uma) {
memset((uint8_t*)dst->ptr + offset, c, size);
return;
}
std::lock_guard<std::recursive_mutex> guard(dst->device->mutex);
vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
ggml_vk_ctx_begin(dst->device, subctx);
subctx->s->buffer->buf.fillBuffer(dst->buffer, offset, size, c);
ggml_vk_ctx_end(subctx);
ggml_vk_submit(subctx, dst->device->fence);
VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_memset waitForFences");
dst->device->device.resetFences({ dst->device->fence });
ggml_vk_queue_command_pools_cleanup(dst->device);
}
ggml_backend_buffer_i ggml_backend_vk_buffer_interface = {
/* .free_buffer = */ ggml_backend_vk_buffer_free_buffer,
/* .get_base = */ ggml_backend_vk_buffer_get_base,
/* .init_tensor = */ ggml_backend_vk_buffer_init_tensor,
/* .memset_tensor = */ ggml_backend_vk_buffer_memset_tensor,
/* .set_tensor = */ ggml_backend_vk_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_vk_buffer_get_tensor,
/* .set_tensor_2d = */ ggml_backend_vk_buffer_set_tensor_2d,
/* .get_tensor_2d = */ ggml_backend_vk_buffer_get_tensor_2d,
/* .cpy_tensor = */ ggml_backend_vk_buffer_cpy_tensor,
/* .clear = */ ggml_backend_vk_buffer_clear,
/* .reset = */ NULL,
};
vk_buffer ggml_vk_buffer_from_host_ptr(vk_device & device, void * ptr, size_t size) {
if (!device->external_memory_host) {
return {};
}
uintptr_t uptr = reinterpret_cast<uintptr_t>(ptr);
if (uptr & (device->min_imported_host_pointer_alignment - 1)) {
return {};
}
if (size & (device->min_imported_host_pointer_alignment - 1)) {
return {};
}
const vk::MemoryPropertyFlags property_flags = vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached;
vk_buffer buf {};
try {
buf = ggml_vk_create_buffer(device, size, { property_flags }, ptr);
} catch (vk::SystemError& e) {
GGML_LOG_WARN("ggml_vulkan: Failed ggml_vk_create_buffer (%s)\n", e.what());
}
return buf;
}
+228
View File
@@ -0,0 +1,228 @@
#pragma once
#include "ggml-vulkan-push-constants.h"
extern std::mutex queue_mutex;
extern ggml_backend_buffer_type_i ggml_backend_vk_buffer_type_interface;
extern bool vk_memory_logger_enabled;
extern bool vk_perf_logger_enabled;
extern bool vk_perf_logger_concurrent;
extern bool vk_enable_sync_logger;
extern uint32_t vk_perf_logger_frequency;
extern std::string vk_pipeline_stats_filter;
extern void * const vk_ptr_base;
extern vk_instance_t vk_instance;
extern ggml_backend_buffer_i ggml_backend_vk_buffer_interface;
uint64_t vk_tensor_offset(const ggml_tensor * tensor);
uint32_t get_misalign_bytes(const ggml_backend_vk_context * ctx, const ggml_tensor * t);
void ggml_vk_wait_for_fence(ggml_backend_vk_context * ctx);
void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline);
void ggml_pipeline_request_descriptor_sets(ggml_backend_vk_context *ctx, vk_pipeline& pipeline, uint32_t n);
void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx);
void ggml_vk_submit(vk_context& ctx, vk::Fence fence);
uint32_t ggml_vk_find_queue_family_index(std::vector<vk::QueueFamilyProperties>& queue_family_props, const vk::QueueFlags& required, const vk::QueueFlags& avoid, int32_t compute_index, uint32_t min_num_queues);
void ggml_vk_create_queue(vk_device& device, vk_queue& q, uint32_t queue_family_index, uint32_t queue_index, vk::PipelineStageFlags&& stage_flags, bool transfer_only);
vk_context ggml_vk_create_context(ggml_backend_vk_context * ctx, vk_command_pool& p);
vk_context ggml_vk_create_temporary_context(vk_command_pool& p);
void ggml_vk_command_pool_cleanup(vk_device& device, vk_command_pool& p);
void ggml_vk_queue_command_pools_cleanup(vk_device& device);
vk_buffer ggml_vk_create_buffer_check(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0));
vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size);
void ggml_vk_destroy_buffer(vk_buffer& buf);
vk_subbuffer ggml_vk_subbuffer(const ggml_backend_vk_context* ctx, const vk_buffer& buf, size_t offset = 0);
void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subctx);
void ggml_vk_set_event(vk_context& ctx, vk::Event& event);
void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events);
vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc);
vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool aligned, bool f32acc, bool use_mask, bool use_mask_opt, bool use_logit_softcap, ggml_type k_type, ggml_type v_type);
uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_device_architecture &arch);
void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested = nullptr);
vk_device ggml_vk_get_device(size_t idx);
DispatchLoaderDynamic & ggml_vk_default_dispatcher();
void ggml_vk_instance_init();
void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx);
vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type type);
void * ggml_vk_host_malloc(vk_device& device, size_t size);
void ggml_vk_host_free(vk_device& device, void* ptr);
void ggml_vk_host_get(const vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset);
vk_subbuffer ggml_vk_tensor_subbuffer( const ggml_backend_vk_context * ctx, const ggml_tensor * tensor, bool allow_misalign = false);
void ggml_vk_ctx_end(vk_context& ctx);
void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx);
vk_context ggml_vk_get_compute_ctx(ggml_backend_vk_context * ctx);
bool ggml_vk_submit_transfer_ctx(ggml_backend_vk_context * ctx);
size_t ggml_vk_align_size(size_t width, size_t align);
void deferred_memcpy(void * dst, const void * src, size_t size, std::vector<vk_staging_memcpy>* memcpys = nullptr);
void deferred_memset(void * dst, uint32_t val, size_t size, std::vector<vk_staging_memset>* memsets = nullptr);
void ggml_vk_ensure_sync_staging_buffer(vk_device& device, size_t size);
void ggml_vk_ensure_sync_staging_buffer(ggml_backend_vk_context * ctx, size_t size);
bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t dpitch, size_t width, size_t height, bool sync_staging = false);
bool ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging = false);
void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t dpitch, size_t width, size_t height);
void ggml_vk_buffer_write(vk_buffer& dst, size_t offset, const void * src, size_t size);
bool ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height, bool sync_staging = false);
void ggml_vk_buffer_read_2d(vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height);
void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size);
void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size);
void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size);
void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t offset, uint32_t c, size_t size);
void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size);
void ggml_vk_matmul( ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline, vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer, uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d, uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d, uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3, uint32_t padded_n);
bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor);
vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src, const ggml_tensor * dst, ggml_type to);
void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline pipeline, const ggml_tensor * tensor, const vk_subbuffer & in, const vk_subbuffer & out);
vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type);
void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, const vk_subbuffer & in, const vk_subbuffer & out, uint32_t ne);
bool ggml_vk_can_use_fwht(const ggml_backend_vk_context * ctx, const ggml_tensor * src1, const ggml_tensor * dst);
void ggml_vk_fwht(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src, ggml_tensor * dst);
void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx);
bool ggml_vk_use_mul_mat_vec_id(const struct ggml_cgraph * cgraph, int node_idx);
void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx);
bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type, ggml_type v_type);
bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type = GGML_TYPE_F16);
void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, const ggml_tensor * sinks, ggml_tensor * dst);
void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx);
void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_vk_sub(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_vk_add_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst);
void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst);
void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst);
void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst);
void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst);
void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx);
void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst);
void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst);
void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst);
void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst);
void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst);
void ggml_vk_sqrt(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst);
void ggml_vk_add1(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_vk_arange(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst);
void ggml_vk_fill(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst);
void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst);
void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst);
void ggml_vk_log(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst);
void ggml_vk_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst);
void ggml_vk_diag(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst);
void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst);
void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst);
void ggml_vk_roll(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst);
void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst);
void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst);
void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst);
void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst);
void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst);
uint32_t ggml_vk_rms_partials_size(ggml_backend_vk_context * ctx, const ggml_tensor *node);
void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx, float * op_params);
void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst);
void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst);
void ggml_vk_xielu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst);
void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst);
void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst);
void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx);
void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_cgraph * cgraph, int node_idx, bool backprop);
void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst);
void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst);
void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst);
void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst);
void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst);
void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst);
void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst);
void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_vk_solve_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst);
void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_vk_col2im_1d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst);
void ggml_vk_snake_dispatch_fused(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx);
void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst);
void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst);
void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_context subctx);
bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool last_node, bool almost_ready, bool submit);
void ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool almost_ready);
void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx);
void ggml_vk_cleanup(ggml_backend_vk_context * ctx);
int ggml_vk_get_device_count();
void ggml_vk_get_device_description(int device, char * description, size_t description_size);
bool ggml_backend_buffer_is_vk(ggml_backend_buffer_t buffer);
void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer);
void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer);
enum ggml_status ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor);
void ggml_backend_vk_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size);
void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size);
void ggml_backend_vk_buffer_set_tensor_2d(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data);
void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size);
void ggml_backend_vk_buffer_get_tensor_2d(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data);
bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst);
void ggml_backend_vk_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value);
const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft);
ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size);
size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft);
size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft);
size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor);
void ggml_backend_vk_free(ggml_backend_t backend);
void ggml_vk_synchronize(ggml_backend_vk_context * ctx);
bool ggml_vk_is_empty(ggml_tensor * node);
bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops);
bool ggml_vk_can_fuse_ssm_conv(const ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx, int num_extra);
bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx, topk_moe_mode mode);
bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx);
bool ggml_vk_can_fuse_snake(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx);
bool ggml_vk_tensors_overlap(const ggml_tensor * a, const ggml_tensor * b, bool elementwise);
bool ggml_vk_can_fuse_rms_norm_mul_rope(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx);
uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx);
void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * graph);
int64_t ggml_vk_get_op_batch_size(const ggml_tensor * op);
vk_buffer ggml_vk_buffer_from_host_ptr(vk_device & device, void * ptr, size_t size);
ggml_backend_reg_t ggml_backend_vk_reg();
bool ggml_vk_instance_layer_settings_available();
bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
bool ggml_vk_instance_debug_utils_ext_available(const std::vector<vk::ExtensionProperties> & instance_extensions);
bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev);
bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch);
uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev);
template <typename T>
static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list<vk::DescriptorBufferInfo> const& descriptor_buffer_infos, const T &push_constants, std::array<uint32_t, 3> elements) {
const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]);
const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]);
const uint32_t wg2 = CEIL_DIV(elements[2], pipeline->wg_denoms[2]);
VK_LOG_DEBUG("ggml_vk_dispatch_pipeline(" << pipeline->name << ", {";
for (auto& buffer : descriptor_buffer_infos) {
std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), ";
}
std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))");
GGML_ASSERT(wg0 <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] &&
wg1 <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
wg2 <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size());
GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT);
GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size());
GGML_ASSERT(pipeline->push_constant_size == push_constant_size(push_constants));
vk::DescriptorSet& descriptor_set = ctx->descriptor_sets[ctx->descriptor_set_idx++];
vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() };
ctx->device->device.updateDescriptorSets({ write_descriptor_set }, {});
subctx->s->buffer->buf.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size(push_constants), push_constant_data(push_constants));
subctx->s->buffer->buf.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline);
subctx->s->buffer->buf.bindDescriptorSets(vk::PipelineBindPoint::eCompute,
pipeline->layout,
0,
{ descriptor_set },
{});
subctx->s->buffer->buf.dispatch(wg0, wg1, wg2);
}
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,293 @@
#include "ggml-vulkan-common.h"
void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, const ggml_tensor * sinks, ggml_tensor * dst) {
VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3];
std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3];
std::cerr << "), (" << v << ", name=" << v->name << ", type=" << v->type << ", ne0=" << v->ne[0] << ", ne1=" << v->ne[1] << ", ne2=" << v->ne[2] << ", ne3=" << v->ne[3] << ", nb0=" << v->nb[0] << ", nb1=" << v->nb[1] << ", nb2=" << v->nb[2] << ", nb3=" << v->nb[3];
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
if (sinks) {
std::cerr << "), (" << sinks << ", name=" << sinks->name << ", type=" << sinks->type << ", ne0=" << sinks->ne[0] << ", ne1=" << sinks->ne[1] << ", ne2=" << sinks->ne[2] << ", ne3=" << sinks->ne[3] << ", nb0=" << sinks->nb[0] << ", nb1=" << sinks->nb[1] << ", nb2=" << sinks->nb[2] << ", nb3=" << sinks->nb[3];
}
std::cerr << "))");
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const uint32_t nem0 = mask ? mask->ne[0] : 0;
const uint32_t nem1 = mask ? mask->ne[1] : 0;
const uint32_t nem2 = mask ? mask->ne[2] : 0;
const uint32_t nem3 = mask ? mask->ne[3] : 0;
const uint32_t HSK = nek0;
const uint32_t HSV = nev0;
uint32_t N = neq1;
const uint32_t KV = nek1;
GGML_ASSERT(ne0 == HSV);
GGML_ASSERT(ne2 == N);
// input tensor rows must be contiguous
GGML_ASSERT(nbq0 == ggml_type_size(q->type));
GGML_ASSERT(nbk0 == ggml_type_size(k->type));
GGML_ASSERT(nbv0 == ggml_type_size(v->type));
GGML_ASSERT(neq0 == HSK);
GGML_ASSERT(neq1 == N);
GGML_ASSERT(nev1 == nek1);
// dst cannot be transposed or permuted
GGML_ASSERT(nb0 == sizeof(float));
GGML_ASSERT(nb0 <= nb1);
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);
assert(dst->type == GGML_TYPE_F32);
assert(q->type == GGML_TYPE_F32);
uint32_t gqa_ratio = 1;
uint32_t qk_ratio = neq2 / nek2;
uint32_t workgroups_x = (uint32_t)neq1;
uint32_t workgroups_y = (uint32_t)neq2;
uint32_t workgroups_z = (uint32_t)neq3;
const bool f32acc = !ctx->device->fp16 || dst->op_params[3] == GGML_PREC_F32 || k->type == GGML_TYPE_BF16;
// For scalar/coopmat1 FA, we can use the "large" size to accommodate qga.
// For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
vk_fa_tuning_params tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, 512, KV, k->type, v->type, f32acc);
const uint32_t max_gqa = std::min(tuning_params.block_rows, 32u);
if (N <= 8 && qk_ratio > 1 && qk_ratio <= max_gqa &&
qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) {
// grouped query attention - make the N dimension equal to gqa_ratio, reduce
// workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
// and change addressing calculations to index Q's dimension 2.
gqa_ratio = qk_ratio;
N = gqa_ratio;
workgroups_y /= gqa_ratio;
}
tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, v->type, f32acc);
const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
// For F32, the shader treats it as a block of size 4 (for vec4 loads)
if (k->type == GGML_TYPE_F32) {
k_stride /= 4;
}
if (v->type == GGML_TYPE_F32) {
v_stride /= 4;
}
const uint32_t alignment = tuning_params.block_cols;
bool aligned = (KV % alignment) == 0 &&
// the "aligned" shader variant will forcibly align strides, for performance
(q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
// Need to use the coopmat2 variant that clamps loads when HSK/HSV aren't sufficiently aligned.
if (((HSK | HSV) % 16) != 0 && tuning_params.path == FA_COOPMAT2) {
aligned = false;
}
float scale = 1.0f;
float max_bias = 0.0f;
float logit_softcap = 0.0f;
memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
if (logit_softcap != 0) {
scale /= logit_softcap;
}
// Only use mask opt when the mask is fairly large. This hasn't been tuned extensively.
bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768 && nem0 >= tuning_params.block_cols * 16;
vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(ctx->device, tuning_params, HSK, HSV, aligned, f32acc,
mask != nullptr, use_mask_opt, logit_softcap != 0, k->type, v->type);
vk_pipeline pipeline = nullptr;
{
std::lock_guard<std::mutex> guard(ctx->device->compile_mutex);
auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16;
auto it = pipelines.find(fa_pipeline_state);
if (it != pipelines.end()) {
pipeline = it->second;
} else {
pipelines[fa_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();
}
}
assert(pipeline);
// Compile early to initialize wg_denoms.
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
uint32_t split_kv = KV;
uint32_t split_k = 1;
// Intel Alchemist prefers more workgroups
const uint32_t shader_core_count_multiplier = (ctx->device->vendor_id == VK_VENDOR_ID_INTEL && ctx->device->architecture != INTEL_XE2) ? 2 : 1;
// Use a placeholder core count if one isn't available. split_k is a big help for perf.
const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count * shader_core_count_multiplier : 16;
const uint32_t Br = fa_pipeline_state.Br;
const uint32_t Bc = fa_pipeline_state.Bc;
GGML_ASSERT(Br == pipeline->wg_denoms[0]);
const uint32_t Tr = CEIL_DIV(N, Br);
// Try to use split_k when KV is large enough to be worth the overhead.
if (gqa_ratio > 1 && workgroups_x <= Br) {
split_k = shader_core_count * 2 / (workgroups_x * workgroups_y * workgroups_z);
} else if (gqa_ratio <= 1) {
uint32_t total_wgs_no_split = Tr * workgroups_y * workgroups_z;
if (total_wgs_no_split < shader_core_count * 2) {
split_k = shader_core_count * 2 / total_wgs_no_split;
}
}
if (split_k > 1) {
// Try to evenly split KV into split_k chunks, but it needs to be a multiple
// of "align", so recompute split_k based on that.
split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment);
split_k = CEIL_DIV(KV, split_kv);
}
// Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
// and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
// For matrices, the order is (inner to outer) [HSV, ne1, k, ne2, ne3].
// For L/M, the order is (inner to outer) [ne1, k, ne2, ne3].
const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne2 * ne3 : 0;
if (split_k_size > ctx->device->properties.limits.maxStorageBufferRange) {
GGML_ABORT("Requested preallocation size is too large");
}
if (ctx->prealloc_size_split_k < split_k_size) {
ctx->prealloc_size_split_k = split_k_size;
ggml_vk_preallocate_buffers(ctx, subctx);
}
const uint32_t mask_opt_num_dwords = CEIL_DIV(nem0, 16 * Bc);
const uint64_t mask_opt_size = sizeof(uint32_t) * mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2 * nem3;
vk_pipeline pipeline_fa_mask_opt = nullptr;
if (use_mask_opt) {
{
std::lock_guard<std::mutex> guard(ctx->device->compile_mutex);
auto &pipelines = ctx->device->pipeline_fa_mask_opt;
auto it = pipelines.find({Br, Bc});
if (it != pipelines.end()) {
pipeline_fa_mask_opt = it->second;
} else {
pipelines[{Br, Bc}] = pipeline_fa_mask_opt = std::make_shared<vk_pipeline_struct>();
}
}
assert(pipeline_fa_mask_opt);
ggml_pipeline_request_descriptor_sets(ctx, pipeline_fa_mask_opt, 1);
if (ctx->prealloc_size_y < mask_opt_size) {
ctx->prealloc_size_y = mask_opt_size;
ggml_vk_preallocate_buffers(ctx, subctx);
}
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
}
const uint32_t n_head_kv = neq2;
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
vk_subbuffer q_buf = ggml_vk_tensor_subbuffer(ctx, q);
vk_subbuffer k_buf = ggml_vk_tensor_subbuffer(ctx, k);
vk_subbuffer v_buf = ggml_vk_tensor_subbuffer(ctx, v);
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
vk_subbuffer mask_buf = mask ? ggml_vk_tensor_subbuffer(ctx, mask) : q_buf;
vk_subbuffer sinks_buf = sinks ? ggml_vk_tensor_subbuffer(ctx, sinks) : q_buf;
vk_subbuffer mask_opt_buf = use_mask_opt ? ggml_vk_subbuffer(ctx, ctx->prealloc_y, 0) : q_buf;
uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | n_head_log2;
if (use_mask_opt)
{
const vk_op_flash_attn_mask_opt_push_constants opt_pc = {
nem0,
nem1,
nem2,
(uint32_t)(mask->nb[1] / sizeof(ggml_fp16_t)),
(uint32_t)(mask->nb[2] / sizeof(ggml_fp16_t)),
(uint32_t)(mask->nb[3] / sizeof(ggml_fp16_t)),
mask_opt_num_dwords,
mask_opt_num_dwords * CEIL_DIV(nem1, Br),
mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2,
};
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline_fa_mask_opt,
{ mask_buf, mask_opt_buf }, opt_pc,
{ mask_opt_num_dwords, CEIL_DIV(nem1, Br), nem2 * nem3 });
ggml_vk_sync_buffers(ctx, subctx);
}
const vk_flash_attn_push_constants pc = { N, KV,
(uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
(uint32_t)neq2, (uint32_t)neq3,
(uint32_t)nek2, (uint32_t)nek3,
(uint32_t)nev2, (uint32_t)nev3,
nem1, nem2, nem3,
q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
scale, max_bias, logit_softcap,
mask_n_head_log2, m0, m1,
gqa_ratio, split_kv, split_k };
if (split_k > 1) {
ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
if (ctx->prealloc_split_k_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
// We reuse workgroups_x to mean the number of splits, so we need to
// cancel out the divide by wg_denoms[0].
uint32_t dispatch_x;
if (gqa_ratio > 1) {
workgroups_x *= pipeline->wg_denoms[0];
dispatch_x = split_k * workgroups_x;
} else {
dispatch_x = Tr * split_k * pipeline->wg_denoms[0];
}
vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf, mask_opt_buf},
pc, { dispatch_x, workgroups_y, workgroups_z });
ggml_vk_sync_buffers(ctx, subctx);
const vk_op_flash_attn_split_k_reduce_push_constants pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, split_k, (sinks != nullptr) };
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
{split_k_buf, sinks_buf, dst_buf},
pc2, { (uint32_t)ne1, HSV, (uint32_t)(ne2 * ne3) });
ctx->prealloc_split_k_need_sync = true;
} else {
if (gqa_ratio > 1) {
// When using gqa, we want one actual workgroup per batch, so cancel out wg_denoms
workgroups_x *= pipeline->wg_denoms[0];
}
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf, mask_opt_buf},
pc, { workgroups_x, workgroups_y, workgroups_z });
}
}
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,512 @@
#include "ggml-vulkan-common.h"
std::mutex queue_mutex;
void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
uint64_t vk_tensor_offset(const ggml_tensor * tensor) {
if (tensor->view_src) {
return (uint8_t *) tensor->view_src->data - (uint8_t *) vk_ptr_base;
}
return (uint8_t *) tensor->data - (uint8_t *) vk_ptr_base;
}
uint32_t get_misalign_bytes(const ggml_backend_vk_context * ctx, const ggml_tensor * t)
{
return ((vk_tensor_offset(t) + t->view_offs) & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1));;
}
static VkDeviceSize ggml_vk_get_max_buffer_range(const ggml_backend_vk_context * ctx, const vk_buffer &buf, const VkDeviceSize offset) {
const VkDeviceSize range = std::min(VkDeviceSize{buf->size - offset},
VkDeviceSize{ctx->device->properties.limits.maxStorageBufferRange});
return range;
}
void ggml_vk_wait_for_fence(ggml_backend_vk_context * ctx) {
// Use waitForFences while most of the graph executes. Hopefully the CPU can sleep
// during this wait.
if (ctx->almost_ready_fence_pending) {
VK_CHECK(ctx->device->device.waitForFences({ ctx->almost_ready_fence }, true, UINT64_MAX), "almost_ready_fence");
ctx->device->device.resetFences({ ctx->almost_ready_fence });
ctx->almost_ready_fence_pending = false;
}
// Spin (w/pause) waiting for the graph to finish executing.
vk::Result result;
while ((result = ctx->device->device.getFenceStatus(ctx->fence)) != vk::Result::eSuccess) {
if (result != vk::Result::eNotReady) {
fprintf(stderr, "ggml_vulkan: error %s at %s:%d\n", to_string(result).c_str(), __FILE__, __LINE__);
exit(1);
}
for (uint32_t i = 0; i < 100; ++i) {
YIELD();
YIELD();
YIELD();
YIELD();
YIELD();
YIELD();
YIELD();
YIELD();
YIELD();
YIELD();
}
}
ctx->device->device.resetFences({ ctx->fence });
}
void ggml_pipeline_request_descriptor_sets(ggml_backend_vk_context *ctx, vk_pipeline& pipeline, uint32_t n) {
VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")");
ctx->pipeline_descriptor_set_requirements += n;
if (!pipeline->compiled) {
ggml_vk_load_shaders(ctx->device, pipeline);
}
ggml_pipeline_allocate_descriptor_sets(ctx);
}
void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx) {
if (ctx->descriptor_sets.size() >= ctx->pipeline_descriptor_set_requirements) {
// Enough descriptors are available
return;
}
vk_device& device = ctx->device;
// Grow by 50% to avoid frequent allocations
uint32_t needed = std::max(3 * ctx->descriptor_sets.size() / 2, size_t{ctx->pipeline_descriptor_set_requirements});
uint32_t to_alloc = needed - ctx->descriptor_sets.size();
uint32_t pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE - ctx->descriptor_sets.size() % VK_DEVICE_DESCRIPTOR_POOL_SIZE;
uint32_t pool_idx = ctx->descriptor_sets.size() / VK_DEVICE_DESCRIPTOR_POOL_SIZE;
while (to_alloc > 0) {
const uint32_t alloc_count = std::min(pool_remaining, to_alloc);
to_alloc -= alloc_count;
pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE;
if (pool_idx >= ctx->descriptor_pools.size()) {
vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, MAX_PARAMETER_COUNT * VK_DEVICE_DESCRIPTOR_POOL_SIZE);
vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size);
ctx->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info));
}
std::vector<vk::DescriptorSetLayout> layouts(alloc_count);
for (uint32_t i = 0; i < alloc_count; i++) {
layouts[i] = device->dsl;
}
vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(ctx->descriptor_pools[pool_idx], alloc_count, layouts.data());
std::vector<vk::DescriptorSet> sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info);
ctx->descriptor_sets.insert(ctx->descriptor_sets.end(), sets.begin(), sets.end());
pool_idx++;
}
}
static vk_command_buffer* ggml_vk_create_cmd_buffer(vk_device& device, vk_command_pool& p) {
VK_LOG_DEBUG("ggml_vk_create_cmd_buffer()");
vk::CommandBufferAllocateInfo command_buffer_alloc_info(
p.pool,
vk::CommandBufferLevel::ePrimary,
1);
const std::vector<vk::CommandBuffer> cmd_buffers = device->device.allocateCommandBuffers(command_buffer_alloc_info);
p.cmd_buffers.push_back({ cmd_buffers.front(), 0, true });
return &p.cmd_buffers[p.cmd_buffers.size()-1];
}
void ggml_vk_submit(vk_context& ctx, vk::Fence fence) {
if (ctx->seqs.empty()) {
if (fence) {
std::lock_guard<std::mutex> guard(queue_mutex);
ctx->p->q->queue.submit({}, fence);
}
return;
}
VK_LOG_DEBUG("ggml_vk_submit(" << ctx << ", " << fence << ")");
std::vector<std::vector<uint64_t>> tl_wait_vals;
std::vector<std::vector<uint64_t>> tl_signal_vals;
std::vector<std::vector<vk::Semaphore>> tl_wait_semaphores;
std::vector<std::vector<vk::Semaphore>> tl_signal_semaphores;
std::vector<vk::TimelineSemaphoreSubmitInfo> tl_submit_infos;
std::vector<vk::SubmitInfo> submit_infos;
int idx = -1;
std::vector<std::vector<vk::PipelineStageFlags>> stage_flags;
size_t reserve = 0;
for (const auto& sequence : ctx->seqs) {
reserve += sequence.size();
}
// Pre-reserve vectors to prevent reallocation, which invalidates pointers
tl_wait_semaphores.reserve(reserve);
tl_wait_vals.reserve(reserve);
tl_signal_semaphores.reserve(reserve);
tl_signal_vals.reserve(reserve);
tl_submit_infos.reserve(reserve);
submit_infos.reserve(reserve);
stage_flags.reserve(reserve);
for (const auto& sequence : ctx->seqs) {
for (const auto& submission : sequence) {
stage_flags.push_back({});
idx++;
tl_wait_vals.push_back({});
tl_wait_semaphores.push_back({});
tl_signal_vals.push_back({});
tl_signal_semaphores.push_back({});
for (size_t i = 0; i < submission.wait_semaphores.size(); i++) {
stage_flags[idx].push_back(ctx->p->q->stage_flags);
tl_wait_vals[idx].push_back(submission.wait_semaphores[i].value);
tl_wait_semaphores[idx].push_back(submission.wait_semaphores[i].s);
}
for (size_t i = 0; i < submission.signal_semaphores.size(); i++) {
tl_signal_vals[idx].push_back(submission.signal_semaphores[i].value);
tl_signal_semaphores[idx].push_back(submission.signal_semaphores[i].s);
}
tl_submit_infos.push_back({
(uint32_t) submission.wait_semaphores.size(),
tl_wait_vals[idx].data(),
(uint32_t) submission.signal_semaphores.size(),
tl_signal_vals[idx].data(),
});
tl_submit_infos[idx].sType = vk::StructureType::eTimelineSemaphoreSubmitInfo;
tl_submit_infos[idx].pNext = nullptr;
vk::SubmitInfo si{
(uint32_t) submission.wait_semaphores.size(),
tl_wait_semaphores[idx].data(),
stage_flags[idx].data(),
1,
&submission.buffer->buf,
(uint32_t) submission.signal_semaphores.size(),
tl_signal_semaphores[idx].data(),
};
si.setPNext(&tl_submit_infos[idx]);
submit_infos.push_back(si);
}
}
std::lock_guard<std::mutex> guard(queue_mutex);
ctx->p->q->queue.submit(submit_infos, fence);
ctx->seqs.clear();
}
uint32_t ggml_vk_find_queue_family_index(std::vector<vk::QueueFamilyProperties>& queue_family_props, const vk::QueueFlags& required, const vk::QueueFlags& avoid, int32_t compute_index, uint32_t min_num_queues) {
VK_LOG_DEBUG("ggml_vk_find_queue_family_index()");
const uint32_t qfsize = queue_family_props.size();
// Try with avoid preferences first
for (uint32_t i = 0; i < qfsize; i++) {
if (queue_family_props[i].queueCount >= min_num_queues && (compute_index < 0 || i != (uint32_t) compute_index) && queue_family_props[i].queueFlags & required && !(queue_family_props[i].queueFlags & avoid)) {
return i;
}
}
// Fall back to only required
for (size_t i = 0; i < qfsize; i++) {
if (queue_family_props[i].queueCount >= min_num_queues && (compute_index < 0 || i != (uint32_t) compute_index) && queue_family_props[i].queueFlags & required) {
return i;
}
}
// Fall back to reusing compute queue
for (size_t i = 0; i < qfsize; i++) {
if (queue_family_props[i].queueCount >= min_num_queues && queue_family_props[i].queueFlags & required) {
return i;
}
}
// Fall back to ignoring min_num_queries
for (size_t i = 0; i < qfsize; i++) {
if (queue_family_props[i].queueFlags & required) {
return i;
}
}
// All commands that are allowed on a queue that supports transfer operations are also allowed on a queue that supports either graphics or compute operations.
// Thus, if the capabilities of a queue family include VK_QUEUE_GRAPHICS_BIT or VK_QUEUE_COMPUTE_BIT, then reporting the VK_QUEUE_TRANSFER_BIT capability separately for that queue family is optional.
if (compute_index >= 0) {
return compute_index;
}
std::cerr << "ggml_vulkan: No suitable queue family index found." << std::endl;
for(auto &q_family : queue_family_props) {
std::cerr << "Queue number: " + std::to_string(q_family.queueCount) << " flags: " + to_string(q_family.queueFlags) << std::endl;
}
abort();
}
void ggml_vk_create_queue(vk_device& device, vk_queue& q, uint32_t queue_family_index, uint32_t queue_index, vk::PipelineStageFlags&& stage_flags, bool transfer_only) {
VK_LOG_DEBUG("ggml_vk_create_queue()");
std::lock_guard<std::recursive_mutex> guard(device->mutex);
q.queue_family_index = queue_family_index;
q.transfer_only = transfer_only;
q.cmd_pool.init(device, &q);
q.queue = device->device.getQueue(queue_family_index, queue_index);
q.stage_flags = stage_flags;
}
vk_context ggml_vk_create_context(ggml_backend_vk_context * ctx, vk_command_pool& p) {
vk_context result = std::make_shared<vk_context_struct>();
VK_LOG_DEBUG("ggml_vk_create_context(" << result << ")");
ctx->gc.contexts.emplace_back(result);
result->p = &p;
return result;
}
vk_context ggml_vk_create_temporary_context(vk_command_pool& p) {
vk_context result = std::make_shared<vk_context_struct>();
VK_LOG_DEBUG("ggml_vk_create_temporary_context(" << result << ")");
result->p = &p;
return result;
}
static vk_semaphore * ggml_vk_create_binary_semaphore(ggml_backend_vk_context * ctx) {
VK_LOG_DEBUG("ggml_vk_create_timeline_semaphore()");
vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eBinary, 0 };
vk::SemaphoreCreateInfo ci{};
ci.setPNext(&tci);
vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci);
ctx->gc.semaphores.push_back({ semaphore, 0 });
return &ctx->gc.semaphores[ctx->gc.semaphores.size() - 1];
}
static vk_semaphore * ggml_vk_create_timeline_semaphore(ggml_backend_vk_context * ctx) {
VK_LOG_DEBUG("ggml_vk_create_timeline_semaphore()");
if (ctx->semaphore_idx >= ctx->gc.tl_semaphores.size()) {
vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 };
vk::SemaphoreCreateInfo ci{};
ci.setPNext(&tci);
vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci);
ctx->gc.tl_semaphores.push_back({ semaphore, 0 });
}
return &ctx->gc.tl_semaphores[ctx->semaphore_idx++];
}
static vk::Event ggml_vk_create_event(ggml_backend_vk_context * ctx) {
if (ctx->event_idx >= ctx->gc.events.size()) {
ctx->gc.events.push_back(ctx->device->device.createEvent({}));
}
return ctx->gc.events[ctx->event_idx++];
}
void ggml_vk_command_pool_cleanup(vk_device& device, vk_command_pool& p) {
VK_LOG_DEBUG("ggml_vk_command_pool_cleanup()");
// Requires command buffers to be done
device->device.resetCommandPool(p.pool);
// Don't clear the command buffers and mark them as not in use.
// This allows us to reuse them
for (auto& cmd_buffer : p.cmd_buffers) {
cmd_buffer.in_use = false;
}
}
void ggml_vk_queue_command_pools_cleanup(vk_device& device) {
VK_LOG_DEBUG("ggml_vk_queue_command_pools_cleanup()");
// Arbitrary frequency to cleanup/reuse command buffers
static constexpr uint32_t cleanup_frequency = 10;
if (device->compute_queue.cmd_pool.buffers_in_use() >= cleanup_frequency) {
ggml_vk_command_pool_cleanup(device, device->compute_queue.cmd_pool);
}
if (device->transfer_queue.cmd_pool.buffers_in_use() >= cleanup_frequency) {
ggml_vk_command_pool_cleanup(device, device->transfer_queue.cmd_pool);
}
}
vk_subbuffer ggml_vk_subbuffer(const ggml_backend_vk_context* ctx, const vk_buffer& buf, size_t offset) {
return { buf, offset, ggml_vk_get_max_buffer_range(ctx, buf, offset) };
}
void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subctx) {
VK_LOG_DEBUG("ggml_vk_sync_buffers()");
const bool transfer_queue = subctx->p->q->transfer_only;
if (ctx) {
ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false;
}
subctx->s->buffer->buf.pipelineBarrier(
subctx->p->q->stage_flags,
subctx->p->q->stage_flags,
{},
{ {
{ !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) },
{ !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) }
} },
{},
{}
);
}
static void ggml_vk_reset_event(vk_context& ctx, vk::Event& event) {
VK_LOG_DEBUG("ggml_vk_set_event()");
ctx->s->buffer->buf.resetEvent(
event,
ctx->p->q->stage_flags
);
}
void ggml_vk_set_event(vk_context& ctx, vk::Event& event) {
VK_LOG_DEBUG("ggml_vk_set_event()");
ctx->s->buffer->buf.setEvent(
event,
ctx->p->q->stage_flags
);
}
void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events) {
VK_LOG_DEBUG("ggml_vk_wait_events()");
if (events.empty()) {
return;
}
ctx->s->buffer->buf.waitEvents(
events,
ctx->p->q->stage_flags,
ctx->p->q->stage_flags,
{},
{},
{}
);
}
vk_subbuffer ggml_vk_tensor_subbuffer(
const ggml_backend_vk_context * ctx, const ggml_tensor * tensor, bool allow_misalign) {
vk_buffer buffer = nullptr;
size_t offset = 0;
if (ctx->device->uma) {
ggml_vk_host_get(ctx->device, tensor->data, buffer, offset);
}
if (!buffer) {
auto buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
buffer = buf_ctx->dev_buffer;
offset = vk_tensor_offset(tensor) + tensor->view_offs;
}
GGML_ASSERT(buffer != nullptr);
size_t size = ggml_nbytes(tensor);
size_t misalign_bytes = offset & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
// The shader must support misaligned offsets when indexing into the buffer
GGML_ASSERT(allow_misalign || misalign_bytes == 0);
offset &= ~misalign_bytes;
size += misalign_bytes;
return vk_subbuffer{buffer, offset, size};
}
static vk_command_buffer* ggml_vk_get_or_create_cmd_buffer(vk_device& device, vk_command_pool& pool) {
for (auto& cmd_buffer : pool.cmd_buffers) {
if (!cmd_buffer.in_use) {
cmd_buffer.use_counter++;
cmd_buffer.in_use = true;
return &cmd_buffer;
}
}
return ggml_vk_create_cmd_buffer(device, pool);
}
static vk_submission ggml_vk_begin_submission(vk_device& device, vk_command_pool& p, bool one_time = true) {
vk_submission s;
s.buffer = ggml_vk_get_or_create_cmd_buffer(device, p);
if (one_time) {
s.buffer->buf.begin({ vk::CommandBufferUsageFlagBits::eOneTimeSubmit });
} else {
s.buffer->buf.begin({ vk::CommandBufferUsageFlags{} });
}
return s;
}
void ggml_vk_ctx_end(vk_context& ctx) {
VK_LOG_DEBUG("ggml_vk_ctx_end(" << ctx << ", " << ctx->seqs.size() << ")");
if (ctx->s == nullptr) {
return;
}
ctx->s->buffer->buf.end();
ctx->s = nullptr;
}
void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) {
VK_LOG_DEBUG("ggml_vk_ctx_begin(" << device->name << ")");
if (subctx->s != nullptr) {
ggml_vk_ctx_end(subctx);
}
subctx->seqs.push_back({ ggml_vk_begin_submission(device, *subctx->p) });
subctx->s = subctx->seqs[subctx->seqs.size() - 1].data();
}
vk_context ggml_vk_get_compute_ctx(ggml_backend_vk_context * ctx) {
vk_context result;
if (!ctx->compute_ctx.expired()) {
result = ctx->compute_ctx.lock();
} else {
result = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
ctx->compute_ctx = result;
ggml_vk_ctx_begin(ctx->device, result);
}
if (ctx->device->async_use_transfer_queue && ctx->transfer_semaphore_last_submitted < ctx->transfer_semaphore.value) {
result->s->wait_semaphores.push_back(ctx->transfer_semaphore);
ctx->transfer_semaphore_last_submitted = ctx->transfer_semaphore.value;
}
return result;
}
bool ggml_vk_submit_transfer_ctx(ggml_backend_vk_context * ctx) {
if (!ctx->device->async_use_transfer_queue || ctx->transfer_ctx.expired()) {
return false;
}
vk_context cpy_ctx = ctx->transfer_ctx.lock();
ggml_vk_ctx_end(cpy_ctx);
for (auto& cpy : cpy_ctx->in_memcpys) {
memcpy(cpy.dst, cpy.src, cpy.n);
}
ctx->transfer_semaphore.value++;
cpy_ctx->seqs.back().back().signal_semaphores.push_back(ctx->transfer_semaphore);
ggml_vk_submit(cpy_ctx, {});
ctx->transfer_ctx.reset();
return true;
}
size_t ggml_vk_align_size(size_t width, size_t align) {
VK_LOG_DEBUG("ggml_vk_align_size(" << width << ", " << align << ")");
return CEIL_DIV(width, align) * align;
}
void deferred_memcpy(void * dst, const void * src, size_t size, std::vector<vk_staging_memcpy>* memcpys) {
if (memcpys == nullptr) {
memcpy(dst, src, size);
} else {
memcpys->emplace_back(dst, src, size);
}
}
void deferred_memset(void * dst, uint32_t val, size_t size, std::vector<vk_staging_memset>* memsets) {
if (memsets == nullptr) {
memset(dst, val, size);
} else {
memsets->emplace_back(dst, val, size);
}
}
@@ -0,0 +1,872 @@
#pragma once
#include "ggml-vulkan-types.h"
uint32_t get_misalign_bytes(const ggml_backend_vk_context * ctx, const ggml_tensor * t);
struct vk_mat_mat_push_constants {
uint32_t M; uint32_t N; uint32_t K;
uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
uint32_t base_work_group_z; uint32_t num_batches;
uint32_t k_split;
uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
uint32_t padded_N;
};
struct vk_mat_vec_push_constants {
uint32_t ncols;
uint32_t stride_a;
uint32_t stride_b;
uint32_t stride_d;
uint32_t batch_stride_a;
uint32_t batch_stride_b;
uint32_t batch_stride_d;
uint32_t fusion_flags;
uint32_t base_work_group_y;
uint32_t ne02;
uint32_t ne12;
uint32_t broadcast2;
uint32_t broadcast3;
};
struct vk_mat_vec_p021_push_constants {
uint32_t ncols_x;
uint32_t nrows_x;
uint32_t nchannels_x;
uint32_t nchannels_y;
uint32_t b_offset;
uint32_t d_offset;
uint32_t fusion_flags;
};
struct vk_mat_vec_nc_push_constants {
uint32_t ncols_x;
uint32_t nrows_x;
uint32_t row_stride_x;
uint32_t channel_stride_x;
uint32_t channel_stride_y;
uint32_t channel_x_divisor;
uint32_t ne12;
uint32_t b_offset;
uint32_t d_offset;
uint32_t nb03;
uint32_t nb13;
uint32_t nb23;
uint32_t fusion_flags;
};
struct vk_mat_mat_id_push_constants {
uint32_t M; uint32_t N; uint32_t K;
uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
uint32_t nei0; uint32_t nei1; uint32_t nbi1; uint32_t ne11;
uint32_t padded_N;
};
struct vk_mat_vec_id_push_constants {
uint32_t ncols;
uint32_t stride_a;
uint32_t stride_b;
uint32_t stride_d;
uint32_t batch_stride_a;
uint32_t batch_stride_b;
uint32_t batch_stride_d;
uint32_t fusion_flags;
uint32_t nei0;
uint32_t ne11;
uint32_t expert_i1;
uint32_t nbi1;
};
struct vk_flash_attn_push_constants {
uint32_t N;
uint32_t KV;
uint32_t ne1;
uint32_t ne2;
uint32_t ne3;
uint32_t neq2;
uint32_t neq3;
uint32_t nek2;
uint32_t nek3;
uint32_t nev2;
uint32_t nev3;
uint32_t nem1;
uint32_t nem2;
uint32_t nem3;
uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t nb21;
uint32_t nb22;
uint32_t nb23;
float scale;
float max_bias;
float logit_softcap;
uint32_t mask_n_head_log2;
float m0;
float m1;
uint32_t gqa_ratio;
uint32_t split_kv;
uint32_t k_num;
};
static_assert(sizeof(vk_flash_attn_push_constants) <= 128, "sizeof(vk_flash_attn_push_constants) must be <= 128");
struct vk_op_push_constants {
uint32_t KX;
uint32_t KY;
float param1;
float param2;
float param3;
float param4;
};
struct vk_op_fwht_push_constants {
uint32_t n_rows;
uint32_t src_offset;
uint32_t dst_offset;
float scale;
};
struct vk_op_count_experts_push_constants {
uint32_t ne00;
uint32_t ne01;
uint32_t nb00;
uint32_t nb01;
uint32_t a_offset;
};
struct vk_op_glu_push_constants {
uint32_t N;
uint32_t ne00;
uint32_t ne20;
uint32_t mode; // 0: default, 1: swapped, 2: split
float alpha; // for swiglu_oai
float limit;
uint32_t nb00;
uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t nb10;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t nb20;
uint32_t nb21;
uint32_t nb22;
uint32_t nb23;
uint32_t ne21;
uint32_t ne22;
uint32_t misalign_offsets;
uint32_t ne2_012mp; uint32_t ne2_012L;
uint32_t ne2_01mp; uint32_t ne2_01L;
uint32_t ne2_0mp; uint32_t ne2_0L;
};
static_assert(sizeof(vk_op_glu_push_constants) <= 128, "sizeof(vk_op_glu_push_constants) must be <= 128");
struct vk_op_unary_push_constants {
uint32_t ne;
uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
uint32_t misalign_offsets;
float param1; float param2; float param3; float param4;
uint32_t ne0_012mp; uint32_t ne0_01mp; uint32_t ne0_0mp; uint32_t ne0_Ls;
uint32_t ne1_012mp; uint32_t ne1_01mp; uint32_t ne1_0mp; uint32_t ne1_Ls;
};
static_assert(sizeof(vk_op_unary_push_constants) <= 128, "sizeof(vk_op_unary_push_constants) must be <= 128");
static vk_op_unary_push_constants vk_op_unary_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst, int64_t ne = 0) {
GGML_ASSERT(ne != 0 || (ggml_nelements(src0) == ggml_nelements(dst)));
ne = ne != 0 ? ne : ggml_nelements(dst);
GGML_ASSERT(ne <= (int64_t)std::numeric_limits<uint32_t>::max());
vk_op_unary_push_constants p{};
p.ne = (uint32_t)ne;
size_t src0_tsize = ggml_type_size(src0->type);
p.ne00 = (uint32_t)src0->ne[0];
p.ne01 = (uint32_t)src0->ne[1];
p.ne02 = (uint32_t)src0->ne[2];
p.ne03 = (uint32_t)src0->ne[3];
p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize);
p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize);
p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize);
p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize);
size_t dst_tsize = ggml_type_size(dst->type);
p.ne10 = (uint32_t)dst->ne[0];
p.ne11 = (uint32_t)dst->ne[1];
p.ne12 = (uint32_t)dst->ne[2];
p.ne13 = (uint32_t)dst->ne[3];
p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize);
p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize);
p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize);
p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize);
return p; // offsets are initialized later in ggml_vk_op
}
struct vk_op_pad_push_constants {
uint32_t ne;
uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
uint32_t misalign_offsets;
uint32_t circular;
uint32_t lp0; uint32_t rp0;
uint32_t lp1; uint32_t rp1;
uint32_t lp2; uint32_t rp2;
uint32_t lp3; uint32_t rp3;
};
static vk_op_pad_push_constants vk_op_pad_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst) {
int64_t ne = ggml_nelements(dst);
GGML_ASSERT(ne <= (int64_t)std::numeric_limits<uint32_t>::max());
vk_op_pad_push_constants p{};
p.ne = (uint32_t)ne;
size_t src0_tsize = ggml_type_size(src0->type);
p.ne00 = (uint32_t)src0->ne[0];
p.ne01 = (uint32_t)src0->ne[1];
p.ne02 = (uint32_t)src0->ne[2];
p.ne03 = (uint32_t)src0->ne[3];
p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize);
p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize);
p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize);
p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize);
size_t dst_tsize = ggml_type_size(dst->type);
p.ne10 = (uint32_t)dst->ne[0];
p.ne11 = (uint32_t)dst->ne[1];
p.ne12 = (uint32_t)dst->ne[2];
p.ne13 = (uint32_t)dst->ne[3];
p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize);
p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize);
p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize);
p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize);
p.lp0 = dst->op_params[0];
p.rp0 = dst->op_params[1];
p.lp1 = dst->op_params[2];
p.rp1 = dst->op_params[3];
p.lp2 = dst->op_params[4];
p.rp2 = dst->op_params[5];
p.lp3 = dst->op_params[6];
p.rp3 = dst->op_params[7];
p.circular = dst->op_params[8];
return p; // fastdiv values and offsets are initialized later in ggml_vk_op
}
static void init_fastdiv_values(uint32_t d, uint32_t &mp, uint32_t &L)
{
// compute L = ceil(log2(d));
L = 0;
while (L < 32 && (uint32_t{1} << L) < d) {
L++;
}
mp = (uint32_t)((uint64_t{1} << 32) * ((uint64_t{1} << L) - d) / d + 1);
}
static uint32_t pack_fastdiv_L(uint32_t L0, uint32_t L1, uint32_t L2) {
return L0 | (L1 << 8) | (L2 << 16);
}
template <typename T> void init_pushconst_fastdiv(T &p) {
GGML_UNUSED(p);
static_assert(!std::is_const<T>::value, "unexpected type");
}
template <> inline void init_pushconst_fastdiv(vk_op_unary_push_constants &p) {
// Compute magic values to divide by these six numbers.
uint32_t ne0_012L;
uint32_t ne0_01L;
uint32_t ne0_0L;
uint32_t ne1_012L;
uint32_t ne1_01L;
uint32_t ne1_0L;
init_fastdiv_values(p.ne02*p.ne01*p.ne00, p.ne0_012mp, ne0_012L);
init_fastdiv_values(p.ne01*p.ne00, p.ne0_01mp, ne0_01L);
init_fastdiv_values(p.ne00, p.ne0_0mp, ne0_0L);
init_fastdiv_values(p.ne12*p.ne11*p.ne10, p.ne1_012mp, ne1_012L);
init_fastdiv_values(p.ne11*p.ne10, p.ne1_01mp, ne1_01L);
init_fastdiv_values(p.ne10, p.ne1_0mp, ne1_0L);
p.ne0_Ls = pack_fastdiv_L(ne0_012L, ne0_01L, ne0_0L);
p.ne1_Ls = pack_fastdiv_L(ne1_012L, ne1_01L, ne1_0L);
}
template <> inline void init_pushconst_fastdiv(vk_op_glu_push_constants &p) {
// GLU linearizes over dst, then uses dst coordinates for src0/src1.
init_fastdiv_values(p.ne22*p.ne21*p.ne20, p.ne2_012mp, p.ne2_012L);
init_fastdiv_values(p.ne21*p.ne20, p.ne2_01mp, p.ne2_01L);
init_fastdiv_values(p.ne20, p.ne2_0mp, p.ne2_0L);
}
struct vk_op_binary_push_constants {
uint32_t ne;
uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23; uint32_t nb20; uint32_t nb21; uint32_t nb22; uint32_t nb23;
uint32_t misalign_offsets;
float param1; float param2; int32_t param3;
};
struct vk_op_multi_add_push_constants {
// shape for dst
uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23;
// strides for srcs+dst
uint32_t nb[MAX_PARAMETER_COUNT][4];
uint32_t rms_partials;
};
static_assert(MAX_PARAMETER_COUNT == 12);
static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
struct vk_op_topk_moe_push_constants {
uint32_t n_rows;
uint32_t n_experts_push;
uint32_t n_expert_used;
float clamp_min;
float clamp_max;
uint32_t gating_func;
uint32_t has_bias;
uint32_t with_norm;
float output_scale;
float output_bias;
};
struct vk_op_add_id_push_constants {
uint32_t ne0;
uint32_t ne1;
uint32_t s01;
uint32_t s02;
uint32_t s11;
uint32_t s21;
};
struct vk_op_diag_mask_push_constants {
uint32_t ncols;
uint32_t rows_per_channel;
int32_t n_past;
};
struct vk_op_rope_push_constants {
uint32_t rope_mode;
uint32_t nrows;
uint32_t n_dims;
float freq_scale;
float freq_base;
float ext_factor;
float attn_factor;
float corr_dims[2];
float theta_scale;
uint32_t has_ff;
int32_t sections[4];
uint32_t is_imrope;
uint32_t is_back;
uint32_t set_rows_stride;
uint32_t ne00;
uint32_t ne01;
uint32_t ne02;
uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t a_offset;
uint32_t d_offset;
};
static_assert(sizeof(vk_op_rope_push_constants) <= 128, "sizeof(vk_op_rope_push_constants) must be <= 128");
struct vk_op_rms_norm_mul_rope_push_constants {
vk_op_binary_push_constants bin;
vk_op_rope_push_constants rope;
};
struct vk_op_soft_max_push_constants {
uint32_t KX;
uint32_t KY;
uint32_t ne00;
uint32_t ne01;
uint32_t ne02;
uint32_t ne12;
uint32_t ne13;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
float scale;
float max_bias;
float m0;
float m1;
uint32_t n_head_log2;
uint32_t nrows_x;
uint32_t has_sinks;
};
struct vk_op_argsort_push_constants {
uint32_t ncols;
uint32_t ncols_padded;
uint32_t ncols_padded_log2;
uint32_t nrows;
uint32_t order;
uint32_t outer_start;
uint32_t outer_end;
uint32_t inner_start;
uint32_t inner_end;
};
struct vk_op_topk_push_constants {
uint32_t orig_ncols;
uint32_t ncols_input;
uint32_t ncols_output;
uint32_t k;
uint32_t nrows;
uint32_t first_pass;
uint32_t last_pass;
};
struct vk_op_im2col_push_constants {
uint64_t dst_addr;
uint32_t batch_offset; uint32_t offset_delta;
uint32_t IC;
uint32_t IW; uint32_t IH;
uint32_t OW; uint32_t OH;
uint32_t KW; uint32_t KH;
uint32_t OH_batch;
uint32_t CHW;
int32_t s0; int32_t s1;
int32_t p0; int32_t p1;
int32_t d0; int32_t d1;
uint32_t batch_IC;
};
struct vk_op_im2col_3d_push_constants {
uint64_t dst_addr;
uint32_t nb10;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t s0;
uint32_t s1;
uint32_t s2;
uint32_t p0;
uint32_t p1;
uint32_t p2;
uint32_t d0;
uint32_t d1;
uint32_t d2;
uint32_t IW;
uint32_t IH;
uint32_t ID;
uint32_t IC;
uint32_t KW;
uint32_t OH;
uint32_t KD_KH_KW;
uint32_t KH_KW;
uint32_t IC_KD_KH_KW;
uint32_t N_OD_OH;
uint32_t OD_OH;
uint32_t OD_OH_OW_IC_KD_KH_KW;
uint32_t OH_OW_IC_KD_KH_KW;
uint32_t OW_IC_KD_KH_KW;
uint32_t misalign_offsets;
};
struct vk_op_timestep_embedding_push_constants {
uint32_t nb1;
uint32_t dim;
uint32_t max_period;
};
struct vk_op_col2im_1d_push_constants {
uint32_t T_out;
uint32_t OC;
uint32_t K_OC;
uint32_t T_in;
uint32_t K;
int32_t stride;
int32_t p0;
};
struct vk_op_conv_transpose_1d_push_constants {
uint32_t Cout;
uint32_t Cin;
uint32_t K;
uint32_t L;
uint32_t KL;
uint32_t nb01;
uint32_t nb02;
uint32_t nb11;
uint32_t nb1;
int32_t s0;
};
struct vk_op_snake_push_constants {
uint32_t ne0;
uint32_t ne1;
};
struct vk_op_pool2d_push_constants {
uint32_t IW; uint32_t IH;
uint32_t OW; uint32_t OH;
uint32_t OC;
uint32_t pelements;
uint32_t op;
int32_t k0; int32_t k1;
int32_t s0; int32_t s1;
int32_t p0; int32_t p1;
};
struct vk_op_rwkv_wkv6_push_constants {
uint32_t B;
uint32_t T;
uint32_t C;
uint32_t H;
};
struct vk_op_rwkv_wkv7_push_constants {
uint32_t B;
uint32_t T;
uint32_t C;
uint32_t H;
};
struct vk_op_gated_delta_net_push_constants {
uint32_t H;
uint32_t n_tokens;
uint32_t n_seqs;
uint32_t s_off;
uint32_t sq1, sq2, sq3;
uint32_t sv1, sv2, sv3;
uint32_t sb1, sb2, sb3;
uint32_t neq1, rq3;
float scale;
uint32_t K;
};
struct vk_op_ssm_scan_push_constants {
uint32_t nb02, nb03, nb12, nb13;
uint32_t nb21, nb22, nb31;
uint32_t nb42, nb43, nb52, nb53;
uint32_t s_off;
uint32_t n_head, d_head, n_group, n_tok;
};
struct vk_op_ssm_conv_push_constants {
uint32_t nb01, nb02;
uint32_t nb11;
uint32_t dst_nb0, dst_nb1, dst_nb2;
uint32_t nc, ncs, nr, n_t, n_s;
};
struct vk_op_conv2d_push_constants {
uint32_t Cout;
uint32_t Cin;
uint32_t N;
uint32_t W;
uint32_t H;
uint32_t OW;
uint32_t OH;
uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t nb1;
uint32_t nb2;
uint32_t nb3;
// init_fastdiv_values constants for dividing by OW, OW*OH
uint32_t OWmp; uint32_t OWL;
uint32_t OWOHmp; uint32_t OWOHL;
};
template <> inline void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) {
// Compute magic values to divide by OW, OW*OH
init_fastdiv_values(p.OW, p.OWmp, p.OWL);
init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
}
struct vk_op_conv2d_dw_push_constants {
uint32_t ne;
uint32_t batches;
uint32_t channels;
uint32_t dst_w;
uint32_t dst_h;
uint32_t src_w;
uint32_t src_h;
uint32_t knl_w;
uint32_t knl_h;
int32_t stride_x;
int32_t stride_y;
int32_t pad_x;
int32_t pad_y;
int32_t dilation_x;
int32_t dilation_y;
};
struct vk_op_upscale_push_constants {
uint32_t ne; uint32_t a_offset; uint32_t d_offset;
uint32_t ne00; uint32_t ne01;
uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
float sf0; float sf1; float sf2; float sf3;
float pixel_offset;
};
struct vk_op_sum_rows_push_constants
{
uint32_t n_cols;
uint32_t ne01, ne02;
uint32_t nb01, nb02, nb03;
uint32_t nb11, nb12, nb13;
float weight;
uint32_t misalign_offsets;
uint32_t ne0_12mp, ne0_12L;
uint32_t ne0_1mp, ne0_1L;
};
static vk_op_sum_rows_push_constants vk_op_sum_rows_push_constants_init(const ggml_tensor * src, const ggml_tensor * dst, int64_t n_cols) {
uint32_t type_size = (uint32_t)ggml_type_size(src->type);
vk_op_sum_rows_push_constants p = {};
p.n_cols = (uint32_t)n_cols;
p.ne01 = (uint32_t)src->ne[1];
p.ne02 = (uint32_t)src->ne[2];
p.nb01 = (uint32_t)src->nb[1] / type_size;
p.nb02 = (uint32_t)src->nb[2] / type_size;
p.nb03 = (uint32_t)src->nb[3] / type_size;
p.nb11 = (uint32_t)dst->nb[1] / type_size;
p.nb12 = (uint32_t)dst->nb[2] / type_size;
p.nb13 = (uint32_t)dst->nb[3] / type_size;
p.weight = 1.0f;
return p;
}
template <> inline void init_pushconst_fastdiv(vk_op_sum_rows_push_constants &p) {
init_fastdiv_values(p.ne01*p.ne02, p.ne0_12mp, p.ne0_12L);
init_fastdiv_values(p.ne01, p.ne0_1mp, p.ne0_1L);
}
struct vk_quantize_q8_1_push_constants {
uint32_t ne;
uint32_t num_blocks;
};
struct vk_op_flash_attn_split_k_reduce_push_constants {
uint32_t D;
uint32_t ne1;
uint32_t ne2;
uint32_t ne3;
uint32_t k_num;
uint32_t sinks;
};
struct vk_op_flash_attn_mask_opt_push_constants {
uint32_t nem0;
uint32_t nem1;
uint32_t nem2;
uint32_t nbm1;
uint32_t nbm2;
uint32_t nbm3;
uint32_t nbd1;
uint32_t nbd2;
uint32_t nbd3;
};
template <typename T> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, T &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
GGML_UNUSED(p);
GGML_UNUSED(src0);
GGML_UNUSED(src1);
GGML_UNUSED(src2);
GGML_UNUSED(src3);
GGML_UNUSED(dst);
static_assert(!std::is_const<T>::value, "unexpected type");
GGML_ASSERT(!src0 || get_misalign_bytes(ctx, src0) == 0);
GGML_ASSERT(!src1 || get_misalign_bytes(ctx, src1) == 0);
GGML_ASSERT(!src2 || get_misalign_bytes(ctx, src2) == 0);
GGML_ASSERT(!src3 || get_misalign_bytes(ctx, src3) == 0);
GGML_ASSERT(!dst || get_misalign_bytes(ctx, dst) == 0);
}
template <> inline void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_mat_vec_p021_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
p.b_offset = b_offset;
p.d_offset = d_offset;
GGML_UNUSED(src0);
GGML_UNUSED(src2);
GGML_UNUSED(src3);
}
template <> inline void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_mat_vec_nc_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
p.b_offset = b_offset;
p.d_offset = d_offset;
GGML_UNUSED(src0);
GGML_UNUSED(src2);
GGML_UNUSED(src3);
}
template <> inline void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_fwht_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
p.src_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
p.dst_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
GGML_UNUSED(src1);
GGML_UNUSED(src2);
GGML_UNUSED(src3);
}
template <typename T> size_t push_constant_size(const T &t) {
static_assert(std::is_class<T>::value, "T must be a struct/class");
GGML_UNUSED(t);
return sizeof(T);
}
template <typename T> size_t push_constant_size(const std::vector<T> &t) {
GGML_UNUSED(t);
return sizeof(T) * t.size();
}
template <typename T, uint32_t N> size_t push_constant_size(const std::array<T, N> &t) {
GGML_UNUSED(t);
return sizeof(T) * N;
}
template <typename T> const T *push_constant_data(const T &t) {
static_assert(std::is_class<T>::value, "T must be a struct/class");
return &t;
}
template <typename T> const T *push_constant_data(const std::vector<T> &t) {
return t.data();
}
template <typename T, uint32_t N> const T *push_constant_data(const std::array<T, N> &t) {
return t.data();
}
template <> inline void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
p.misalign_offsets = (a_offset << 16) | d_offset;
GGML_UNUSED(src1);
GGML_UNUSED(src2);
GGML_UNUSED(src3);
}
template <> inline void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_glu_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
const uint32_t b_offset = src1 ? get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type) : a_offset;
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
GGML_ASSERT(a_offset < (1u << 8));
GGML_ASSERT(b_offset < (1u << 8));
GGML_ASSERT(d_offset < (1u << 8));
p.misalign_offsets = (a_offset << 16) | (b_offset << 8) | d_offset;
GGML_UNUSED(src2);
GGML_UNUSED(src3);
}
template <> inline void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_sum_rows_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
p.misalign_offsets = (a_offset << 16) | d_offset;
GGML_UNUSED(src1);
GGML_UNUSED(src2);
GGML_UNUSED(src3);
}
template <> inline void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_pad_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
p.misalign_offsets = (a_offset << 16) | d_offset;
GGML_UNUSED(src1);
GGML_UNUSED(src2);
GGML_UNUSED(src3);
}
template <> inline void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_im2col_3d_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
p.misalign_offsets = (a_offset << 16) | d_offset;
GGML_UNUSED(src0);
GGML_UNUSED(src2);
GGML_UNUSED(src3);
}
template <> inline void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
GGML_ASSERT(dst->op != GGML_OP_GET_ROWS || (a_offset == 0 && b_offset == 0 && d_offset == 0));
p.misalign_offsets = (a_offset << 16) | (b_offset << 8) | d_offset;
GGML_UNUSED(src2);
GGML_UNUSED(src3);
}
template <> inline void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_upscale_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
p.a_offset = a_offset;
p.d_offset = d_offset;
GGML_UNUSED(src1);
GGML_UNUSED(src2);
GGML_UNUSED(src3);
}
template <> inline void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_rope_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
p.a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
p.d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
GGML_UNUSED(src1);
GGML_UNUSED(src2);
GGML_UNUSED(src3);
}
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff