Compare commits

..

4 Commits

Author SHA1 Message Date
Isaac McFadyen 2bb0467043 rpc : nicer error messages for RPC server crash (#14076) 2025-06-10 09:41:01 +03:00
Georgi Gerganov b8e2194efc sync : ggml
ggml-ci
2025-06-10 09:21:56 +03:00
Kai Pastor 1a3b5e80f7 Add in-build ggml::ggml ALIAS library (ggml/1260)
Enable uniform linking with subproject and with find_package.
2025-06-10 09:21:56 +03:00
Georgi Gerganov 1f63e75f3b metal : use less stack memory in FA kernel (#14088)
* metal : use less stack memory in FA kernel

ggml-ci

* cont : fix BF16 variant
2025-06-09 23:05:02 +03:00
4 changed files with 76 additions and 77 deletions
+1
View File
@@ -212,6 +212,7 @@ endif()
add_library(ggml
ggml-backend-reg.cpp)
add_library(ggml::ggml ALIAS ggml)
target_link_libraries(ggml PUBLIC ggml-base)
+56 -61
View File
@@ -3333,8 +3333,6 @@ kernel void kernel_flash_attn_ext(
threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*DK); // reuse query data for accumulation
threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*DK); // same as above but in o4_t
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + 2*Q*DK); // scratch buffer for attention, mask and diagonal matrix
threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
@@ -3548,20 +3546,20 @@ kernel void kernel_flash_attn_ext(
// O = diag(ms)*O
{
s8x8_t mm;
simdgroup_load(mm, ss + 2*C, TS, 0, false);
s8x8_t ms;
simdgroup_load(ms, ss + 2*C, TS, 0, false);
#pragma unroll(DV8)
for (short i = 0; i < DV8; ++i) {
simdgroup_multiply(lo[i], mm, lo[i]);
simdgroup_multiply(lo[i], ms, lo[i]);
}
}
// O = O + (Q*K^T)*V
{
for (short cc = 0; cc < C/8; ++cc) {
s8x8_t ms;
simdgroup_load(ms, ss + 8*cc, TS, 0, false);
s8x8_t vs;
simdgroup_load(vs, ss + 8*cc, TS, 0, false);
if (is_same<vd4x4_t, v4x4_t>::value) {
// we can read directly from global memory
@@ -3572,7 +3570,7 @@ kernel void kernel_flash_attn_ext(
v8x8_t mv;
simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20
simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
simdgroup_multiply_accumulate(lo[i], vs, mv, lo[i]);
}
} else {
for (short ii = 0; ii < DV16; ii += 4) {
@@ -3593,10 +3591,10 @@ kernel void kernel_flash_attn_ext(
v8x8_t mv;
simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]);
simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]);
}
} else {
if (ii + tx < DV16) {
@@ -3611,10 +3609,10 @@ kernel void kernel_flash_attn_ext(
v8x8_t mv;
simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]);
simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]);
}
}
}
@@ -3624,83 +3622,80 @@ kernel void kernel_flash_attn_ext(
}
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
for (short j = 0; j < Q; ++j) {
if (tiisg == 0) {
ss[j*TS + 0] = S[j];
ss[j*TS + 1] = M[j];
}
for (short j = tiisg; j < Q; j += NW) {
ss[j*TS + 0] = S[j];
ss[j*TS + 1] = M[j];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup float * so = (threadgroup float *) (shmem_f16 + 0*DK); // reuse query data for accumulation
threadgroup float4 * so4 = (threadgroup float4 *) (shmem_f16 + 0*DK);
// store result to shared memory in F32
if (sgitg == 0) {
for (short i = 0; i < DV8; ++i) {
//simdgroup_store(lo[i], so + i*8, DV, 0, false);
simdgroup_float8x8 t(1.0f);
simdgroup_multiply(t, lo[i], t);
simdgroup_store(t, so + i*8, DV, 0, false);
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// reduce the warps sequentially
for (ushort sg = 1; sg < nsg; ++sg) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// each simdgroup stores its output to shared memory, reusing sq
if (sgitg == sg) {
for (short i = 0; i < DV8; ++i) {
simdgroup_store(lo[i], so + i*8, DV, 0, false);
}
}
for (short j = tiisg; j < Q; j += NW) {
const float S0 = ss[j*TS - 1*SH + 0];
const float S1 = ss[j*TS + 0];
threadgroup_barrier(mem_flags::mem_threadgroup);
// the first simdgroup accumulates the results from the other simdgroups
if (sgitg == 0) {
for (short j = 0; j < Q; ++j) {
const float S0 = ss[j*TS + 0];
const float S1 = ss[j*TS + sg*SH + 0];
const float M0 = ss[j*TS + 1];
const float M1 = ss[j*TS + sg*SH + 1];
const float M0 = ss[j*TS - 1*SH + 1];
const float M1 = ss[j*TS + 1];
const float M = max(M0, M1);
const float ms0 = exp(M0 - M);
const float ms1 = exp(M1 - M);
float ms0 = exp(M0 - M);
float ms1 = exp(M1 - M);
const float S = S0*ms0 + S1*ms1;
if (tiisg == 0) {
ss[j*TS + 0] = S;
ss[j*TS + 1] = M;
ss[j*TS + 0] = S;
ss[j*TS + 1] = M;
ss[j*TS + 2*C + j ] = ms0;
ss[j*TS + 2*C + j + sg*SH] = ms1;
}
ss[j*TS + 2*C + j - 1*SH] = ms0;
ss[j*TS + 2*C + j ] = ms1;
}
//simdgroup_barrier(mem_flags::mem_threadgroup);
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
{
s8x8_t ms0;
s8x8_t ms1;
simdgroup_load(ms0, ss + 2*C, TS, 0, false);
simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
simdgroup_load(ms0, ss + 2*C - 1*SH, TS, 0, false);
simdgroup_load(ms1, ss + 2*C, TS, 0, false);
#pragma unroll(DV8)
for (short i = 0; i < DV8; ++i) {
o8x8_t t;
simdgroup_float8x8 t;
simdgroup_load (t, so + i*8, DV, 0, false);
simdgroup_multiply(t, ms1, t);
simdgroup_multiply(t, ms0, t);
simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
simdgroup_multiply_accumulate(t, ms1, lo[i], t);
simdgroup_store(t, so + i*8, DV, 0, false);
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// store result to shared memory (reuse sq)
if (sgitg == 0) {
for (short i = 0; i < DV8; ++i) {
simdgroup_store(lo[i], so + i*8, DV, 0, false);
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*Q*DK);
threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*(nsg-1)*SH + 2*Q*DK);
// final rescale with 1/S and store to global memory
for (short j = sgitg; j < Q && iq1 + j < args.ne01; j += nsg) {
@@ -3723,8 +3718,8 @@ kernel void kernel_flash_attn_ext(
half, half4x4, simdgroup_half8x8, \
float, simdgroup_float8x8, \
float, simdgroup_float8x8, \
float, float4, simdgroup_float8x8
//half, half4, simdgroup_half8x8
half, half4, simdgroup_half8x8
//float, float4, simdgroup_float8x8
#define FA_TYPES_BF \
bfloat, bfloat4, simdgroup_bfloat8x8, \
@@ -3732,8 +3727,8 @@ kernel void kernel_flash_attn_ext(
bfloat, bfloat4x4, simdgroup_bfloat8x8, \
float, simdgroup_float8x8, \
float, simdgroup_float8x8, \
float, float4, simdgroup_float8x8
//half, half4, simdgroup_half8x8
half, half4, simdgroup_half8x8
//float, float4, simdgroup_float8x8
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
+18 -15
View File
@@ -53,6 +53,9 @@ struct socket_t {
}
};
// macro for nicer error messages on server crash
#define RPC_STATUS_ASSERT(x) if (!(x)) GGML_ABORT("Remote RPC server crashed or returned malformed response")
// all RPC structures must be packed
#pragma pack(push, 1)
// ggml_tensor is serialized into rpc_tensor
@@ -425,7 +428,7 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
static bool check_server_version(const std::shared_ptr<socket_t> & sock) {
rpc_msg_hello_rsp response;
bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));
GGML_ASSERT(status);
RPC_STATUS_ASSERT(status);
if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
fprintf(stderr, "RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
return false;
@@ -481,7 +484,7 @@ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
rpc_msg_free_buffer_req request = {ctx->remote_ptr};
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &request, sizeof(request), nullptr, 0);
GGML_ASSERT(status);
RPC_STATUS_ASSERT(status);
delete ctx;
}
@@ -493,7 +496,7 @@ static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
rpc_msg_buffer_get_base_req request = {ctx->remote_ptr};
rpc_msg_buffer_get_base_rsp response;
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response));
GGML_ASSERT(status);
RPC_STATUS_ASSERT(status);
ctx->base_ptr = reinterpret_cast<void *>(response.base_ptr);
return ctx->base_ptr;
}
@@ -545,7 +548,7 @@ static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_
request.tensor = serialize_tensor(tensor);
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_INIT_TENSOR, &request, sizeof(request), nullptr, 0);
GGML_ASSERT(status);
RPC_STATUS_ASSERT(status);
}
return GGML_STATUS_SUCCESS;
}
@@ -560,7 +563,7 @@ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggm
request.hash = fnv_hash((const uint8_t*)data, size);
rpc_msg_set_tensor_hash_rsp response;
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, &request, sizeof(request), &response, sizeof(response));
GGML_ASSERT(status);
RPC_STATUS_ASSERT(status);
if (response.result) {
// the server has the same data, no need to send it
return;
@@ -573,7 +576,7 @@ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggm
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size());
GGML_ASSERT(status);
RPC_STATUS_ASSERT(status);
}
static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
@@ -583,7 +586,7 @@ static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, con
request.offset = offset;
request.size = size;
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, &request, sizeof(request), data, size);
GGML_ASSERT(status);
RPC_STATUS_ASSERT(status);
}
static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
@@ -601,7 +604,7 @@ static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, con
request.dst = serialize_tensor(dst);
rpc_msg_copy_tensor_rsp response;
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
GGML_ASSERT(status);
RPC_STATUS_ASSERT(status);
return response.result;
}
@@ -609,7 +612,7 @@ static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
rpc_msg_buffer_clear_req request = {ctx->remote_ptr, value};
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, &request, sizeof(request), nullptr, 0);
GGML_ASSERT(status);
RPC_STATUS_ASSERT(status);
}
static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {
@@ -635,7 +638,7 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
rpc_msg_alloc_buffer_rsp response;
auto sock = get_socket(buft_ctx->endpoint);
bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
GGML_ASSERT(status);
RPC_STATUS_ASSERT(status);
if (response.remote_ptr != 0) {
ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
ggml_backend_rpc_buffer_interface,
@@ -650,7 +653,7 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
rpc_msg_get_alignment_rsp response;
bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, nullptr, 0, &response, sizeof(response));
GGML_ASSERT(status);
RPC_STATUS_ASSERT(status);
return response.alignment;
}
@@ -662,7 +665,7 @@ static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_typ
static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
rpc_msg_get_max_size_rsp response;
bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, nullptr, 0, &response, sizeof(response));
GGML_ASSERT(status);
RPC_STATUS_ASSERT(status);
return response.max_size;
}
@@ -683,7 +686,7 @@ static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_ty
rpc_msg_get_alloc_size_rsp response;
bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof(request), &response, sizeof(response));
GGML_ASSERT(status);
RPC_STATUS_ASSERT(status);
return response.alloc_size;
} else {
@@ -761,7 +764,7 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g
rpc_msg_graph_compute_rsp response;
auto sock = get_socket(rpc_ctx->endpoint);
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
GGML_ASSERT(status);
RPC_STATUS_ASSERT(status);
return (enum ggml_status)response.result;
}
@@ -835,7 +838,7 @@ bool ggml_backend_is_rpc(ggml_backend_t backend) {
static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) {
rpc_msg_get_device_memory_rsp response;
bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr, 0, &response, sizeof(response));
GGML_ASSERT(status);
RPC_STATUS_ASSERT(status);
*free = response.free_mem;
*total = response.total_mem;
}
+1 -1
View File
@@ -1 +1 @@
94a83ba5a725ae2aee79df75dd99b2119d0478cc
a1761cd64ce4a75a9770a954a013422c7910db8b