Compare commits

..

1 Commits

Author SHA1 Message Date
lhez fdb2c11c70 opencl: support non-contig rows in norm (#24965) 2026-06-24 19:21:25 -07:00
2 changed files with 13 additions and 15 deletions
+8 -13
View File
@@ -10152,14 +10152,8 @@ static void ggml_cl_norm(ggml_backend_t backend, const ggml_tensor * src0, const
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
const int ne00 = src0 ? src0->ne[0] : 0;
const int ne01 = src0 ? src0->ne[1] : 0;
const int ne02 = src0 ? src0->ne[2] : 0;
const int ne03 = src0 ? src0->ne[3] : 0;
const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
const cl_ulong nb03 = src0 ? src0->nb[3] : 0;
GGML_TENSOR_LOCALS(int, ne0, src0, ne);
GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb);
const int nth = MIN(64, ne00);
@@ -10173,11 +10167,12 @@ static void ggml_cl_norm(ggml_backend_t backend, const ggml_tensor * src0, const
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &eps));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float)*nth, NULL));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &eps));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(float)*nth, NULL));
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
size_t local_work_size[] = {(size_t)nth, 1, 1};
+5 -2
View File
@@ -24,6 +24,7 @@ kernel void kernel_norm(
int ne01,
int ne02,
int ne03,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
@@ -43,7 +44,8 @@ kernel void kernel_norm(
// parallel sum
sum[get_local_id(0)] = 0.0f;
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
sum[get_local_id(0)] += x[i00];
// this kernel handles float, nb00/4 translates byte offset to element offset
sum[get_local_id(0)] += x[i00*nb00/4];
}
// reduce
barrier(CLK_LOCAL_MEM_FENCE);
@@ -60,7 +62,8 @@ kernel void kernel_norm(
global float * y = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
sum[get_local_id(0)] = 0.0f;
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
y[i00] = x[i00] - mean;
// this kernel handles float, nb00/4 translates byte offset to element offset
y[i00] = x[i00*nb00/4] - mean;
sum[get_local_id(0)] += y[i00] * y[i00];
}