Compare commits

...

2 Commits

Author SHA1 Message Date
Neo Zhang f8cc15f163 [SYCL] support bf16 on bin_bcast OP and unary OPs (#24838)
* support bf16 on bin_bcast OP and unary OPs

* support the older Intel compiler than 2026.0
2026-06-22 14:09:02 +03:00
Tim Neumann 37957e8531 sampling : remove unconditional softmax+sort in top-n-sigma sampler (#22645) 2026-06-22 14:08:32 +03:00
4 changed files with 162 additions and 57 deletions
+5
View File
@@ -293,6 +293,11 @@ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_t
(sycl::ext::oneapi::bfloat16 *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2,
ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, ggml_is_contiguous(src0),
ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream);
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_BF16) {
op()((const sycl::ext::oneapi::bfloat16 *) src0->data, (const float *) src1->data,
(sycl::ext::oneapi::bfloat16 *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2,
ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, ggml_is_contiguous(src0),
ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream);
#endif
} else {
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, ggml_type_name(dst->type),
+155 -53
View File
@@ -43,14 +43,44 @@ static __dpct_inline__ T op_sgn(T x) {
return x > static_cast<T>(0.f) ? static_cast<T>(1.f) : ((x < static_cast<T>(0.f) ? static_cast<T>(-1.f) : static_cast<T>(0.f)));
}
template<typename T>
static __dpct_inline__ T op_abs(T x) {
return sycl::fabs(x);
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::fabs(x); // or experimental namespace if needed
} else {
return sycl::fabs(x);
}
}
template<typename T>
static __dpct_inline__ T op_expm1(T x) {
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return static_cast<sycl::ext::oneapi::bfloat16>(
sycl::expm1(static_cast<float>(x))
);
} else {
return sycl::expm1(x);
}
}
template<typename T>
static __dpct_inline__ T op_elu(T x) {
return (x > static_cast<T>(0.f)) ? x : sycl::expm1(x);
return (x > static_cast<T>(0.f)) ? x : op_expm1(x);
}
template<typename T>
static __dpct_inline__ T op_tanh(T x) {
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
constexpr int ver = __INTEL_LLVM_COMPILER;
#if defined(__INTEL_LLVM_COMPILER) && (__INTEL_LLVM_COMPILER >= 20260000)
return sycl::ext::oneapi::experimental::tanh(x);
#else
return static_cast<T>(sycl::tanh(static_cast<float>(x)));
#endif
} else {
return sycl::tanh(x);
}
}
template<typename T>
@@ -59,74 +89,106 @@ static __dpct_inline__ T op_gelu(T x) {
const T SQRT_2_OVER_PI = static_cast<T>(0.79788456080286535587989211986876f);
return static_cast<T>(0.5f) * x *
(static_cast<T>(1.0f) +
sycl::tanh(SQRT_2_OVER_PI * x * (static_cast<T>(1.0f) + GELU_COEF_A * x * x)));
op_tanh(SQRT_2_OVER_PI * x * (static_cast<T>(1.0f) + GELU_COEF_A * x * x)));
}
template<typename T>
static __dpct_inline__ T op_exp(T x) {
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::exp(x);
} else {
return sycl::exp(x);
}
}
template<typename T>
static __dpct_inline__ T op_silu(T x) {
return x / (static_cast<T>(1.0f) + sycl::native::exp(-x));
return x / (static_cast<T>(1.0f) + op_exp(-x));
}
template<typename T>
static __dpct_inline__ T op_gelu_quick(T x) {
const T GELU_QUICK_COEF_LOCAL = static_cast<T>(-1.702f);
return x * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(GELU_QUICK_COEF_LOCAL * x)));
static __dpct_inline__ T op_erf(T x) {
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return static_cast<sycl::ext::oneapi::bfloat16>(
sycl::erf(static_cast<float>(x))
);
} else {
return sycl::erf(x);
}
}
template<typename T>
static __dpct_inline__ T op_gelu_erf(T x) {
const T SQRT_2_INV = static_cast<T>(0.70710678118654752440084436210484f);
return static_cast<T>(0.5f) * x * (static_cast<T>(1.0f) + sycl::erf(x * SQRT_2_INV));
return static_cast<T>(0.5f) * x * (static_cast<T>(1.0f) + op_erf(x * SQRT_2_INV));
}
template<typename T>
static __dpct_inline__ T op_tanh(T x) {
return sycl::tanh(x);
static __dpct_inline__ T op_gelu_quick(T x) {
const T GELU_QUICK_COEF_LOCAL = static_cast<T>(-1.702f);
return x * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + op_exp(GELU_QUICK_COEF_LOCAL * x)));
}
template<typename T>
static __dpct_inline__ T op_relu(T x) {
return sycl::fmax(x, static_cast<T>(0));
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::fmax(x, static_cast<T>(0));
} else {
return sycl::fmax(x, static_cast<T>(0));
}
}
template<typename T>
static __dpct_inline__ T op_sigmoid(T x) {
return static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(-x));
return static_cast<T>(1.0f) / (static_cast<T>(1.0f) + op_exp(-x));
}
template<typename T>
static __dpct_inline__ T op_sqrt(T x) {
return sycl::sqrt(x);
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::sqrt(x);
} else {
return sycl::sqrt(x);
}
}
template<typename T>
static __dpct_inline__ T op_sin(T x) {
return sycl::sin(x);
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::sin(x);
} else {
return sycl::sin(x);
}
}
template<typename T>
static __dpct_inline__ T op_cos(T x) {
return sycl::cos(x);
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::cos(x);
} else {
return sycl::cos(x);
}
}
template<typename T>
static __dpct_inline__ T op_hardsigmoid(T x) {
return sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::fmin(
static_cast<T>(1.0f), sycl::ext::oneapi::experimental::fmax(
static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
} else {
return sycl::fmin(static_cast<T>(1.0f),
sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
}
}
template<typename T>
static __dpct_inline__ T op_hardswish(T x) {
return x * sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
}
template<typename T>
static __dpct_inline__ T op_exp(T x) {
return sycl::exp(x);
}
template<typename T>
static __dpct_inline__ T op_expm1(T x) {
return sycl::expm1(x);
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return x * sycl::ext::oneapi::experimental::fmin(static_cast<T>(1.0f), sycl::ext::oneapi::experimental::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
} else {
return x * sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
}
}
template<typename T>
@@ -134,13 +196,17 @@ static __dpct_inline__ T op_log(T x) {
if (x <= static_cast<T>(0)) {
return neg_infinity<T>();
}
return sycl::log(x);
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::log(x);
} else {
return sycl::log(x);
}
}
template<typename T>
static __dpct_inline__ T op_softplus(T x) {
const float xf = (float) x;
const float ax = sycl::fabs(xf);
const float ax = op_abs(xf);
const float m = sycl::fmax(xf, 0.0f);
const float y = m + sycl::log1p(sycl::exp(-ax));
return (T) y;
@@ -159,8 +225,14 @@ static __dpct_inline__ T op_step(T x) {
template<typename T>
static __dpct_inline__ T op_leaky_relu(T x, float negative_slope) {
T neg_slope_T = static_cast<T>(negative_slope);
return sycl::fmax(x, static_cast<T>(0)) +
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::fmax(x, static_cast<T>(0)) +
sycl::ext::oneapi::experimental::fmin(x, static_cast<T>(0.0f)) * neg_slope_T;
} else {
return sycl::fmax(x, static_cast<T>(0)) +
sycl::fmin(x, static_cast<T>(0.0f)) * neg_slope_T;
}
}
template<typename T>
@@ -175,22 +247,40 @@ static __dpct_inline__ T op_clamp(T x, float min_val, float max_val) {
template<typename T>
static __dpct_inline__ T op_floor(T x) {
return sycl::floor(x);
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::floor(x);
} else {
return sycl::floor(x);
}
}
template<typename T>
static __dpct_inline__ T op_ceil(T x) {
return sycl::ceil(x);
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::ceil(x);
} else {
return sycl::ceil(x);
}
}
template<typename T>
static __dpct_inline__ T op_round(T x) {
return sycl::round(x);
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return static_cast<sycl::ext::oneapi::bfloat16>(
sycl::round(static_cast<float>(x))
);
} else {
return sycl::round(x);
}
}
template<typename T>
static __dpct_inline__ T op_trunc(T x) {
return sycl::trunc(x);
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::trunc(x);
} else {
return sycl::trunc(x);
}
}
template<typename T, typename F>
@@ -339,7 +429,7 @@ static void acc_f32_sycl(const float *x, const float *y, float *dst,
const int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
[=](sycl::nd_item<3> /*item_ct1*/) {
[=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset);
});
}
@@ -354,8 +444,8 @@ static void arange_kernel(T * dst, const int k, T start, T step,
template<typename KernelInvoker, typename... Args>
static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16 || dst->src[0]->type == GGML_TYPE_BF16);
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_BF16);
GGML_ASSERT(dst->src[0]->type == dst->type);
dpct::queue_ptr main_stream = ctx.stream();
@@ -367,6 +457,14 @@ static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx,
kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
break;
}
#ifdef GGML_SYCL_HAS_BF16
case GGML_TYPE_BF16:
{
auto data_pts = cast_data<sycl::ext::oneapi::bfloat16>(dst);
kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
break;
}
#endif
case GGML_TYPE_F32:
{
auto data_pts = cast_data<float>(dst);
@@ -480,7 +578,7 @@ static inline void ggml_sycl_op_unary(
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
sycl::range<1>(256)),
[=](sycl::nd_item<1> item_ct1) {
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
unary_op_generic_kernel(
src, dst_ptr, k_elements,
ne0, ne1, ne2, ne3,
@@ -508,7 +606,7 @@ static inline void ggml_sycl_op_arange(ggml_backend_sycl_context & ctx, ggml_ten
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE),
sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
arange_kernel(dst_ptr, k, start, step, item_ct1);
});
}
@@ -602,7 +700,7 @@ static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE),
sycl::range<1>(SYCL_EXP_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
unary_op_log_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
@@ -640,7 +738,7 @@ static inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tenso
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQRT_BLOCK_SIZE),
sycl::range<1>(SYCL_SQRT_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
unary_op_sqrt_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
@@ -653,7 +751,7 @@ static inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE),
sycl::range<1>(SYCL_SIN_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
unary_op_sin_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
@@ -666,7 +764,7 @@ static inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE),
sycl::range<1>(SYCL_SIN_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
unary_op_cos_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
@@ -681,7 +779,7 @@ static inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE),
sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
unary_op_leaky_relu_kernel(src, dst_ptr, k_elements, slope, item_ct1);
});
}, negative_slope);
@@ -694,7 +792,7 @@ static inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQR_BLOCK_SIZE),
sycl::range<1>(SYCL_SQR_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
unary_op_sqr_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
@@ -711,7 +809,7 @@ static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tens
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE),
sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
clamp(src, dst_ptr, min_arg, max_arg, k_elements, item_ct1);
});
}, min_val, max_val);
@@ -774,7 +872,8 @@ static inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tens
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
main_stream->parallel_for(
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
gated_op_fused_geglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
});
});
@@ -785,7 +884,8 @@ static inline void ggml_sycl_op_reglu(ggml_backend_sycl_context & ctx, ggml_tens
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_RELU_BLOCK_SIZE); // Using RELU block size for reglu
main_stream->parallel_for(
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
gated_op_fused_reglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
});
});
@@ -796,7 +896,8 @@ static inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_ten
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_SILU_BLOCK_SIZE); // Using SILU block size for swiglu
main_stream->parallel_for(
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)),
sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
gated_op_fused_swiglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
});
});
@@ -811,7 +912,6 @@ __dpct_inline__ float ggml_sycl_op_swiglu_oai_single(float x, float g, float alp
return out_glu;
}
template <typename T>
static void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k,
const int64_t n, const int64_t o0, const int64_t o1,
@@ -845,7 +945,7 @@ static void swiglu_oai_sycl(const T * x,
const int64_t num_blocks = (k + SYCL_GLU_BLOCK_SIZE - 1) / SYCL_GLU_BLOCK_SIZE;
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
swiglu_oai_kernel(x, g, dst, k, n, o0, o1, alpha, limit, item_ct1);
});
}
@@ -899,7 +999,8 @@ static inline void ggml_sycl_op_geglu_erf(ggml_backend_sycl_context & ctx, ggml_
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
main_stream->parallel_for(
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
gated_op_fused_geglu_erf(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
});
});
@@ -910,7 +1011,8 @@ static inline void ggml_sycl_op_geglu_quick(ggml_backend_sycl_context & ctx, ggm
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
main_stream->parallel_for(
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
gated_op_fused_geglu_quick(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
});
});
-2
View File
@@ -2813,8 +2813,6 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t
cur_p->data[i].logit = -INFINITY;
}
}
llama_sampler_softmax_impl(cur_p, true);
}
static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) {
+2 -2
View File
@@ -360,9 +360,9 @@ int main(void) {
test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 1}, {0.241818f, 0.241818f, 0.032727f, 0.241818f, 0.241818f}, 2.0f, 1.1f, 2, 5, {});
test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 4, 7, {});
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f, 0.0f, 0.0f}, 1.00f);
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.0f, 0.0f, 0.428571f, 0.571429f}, 1.00f);
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.1f, 0.2f, 0.3f, 0.4f}, 0.00f); // top_n_sigma == 0 now represents a no-op rather than greedy decoding as of PR#13345
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 3.00f);
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.1f, 0.2f, 0.3f, 0.4f}, 3.00f);
test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
test_sampler_queue(10000, "k", 1, 1.0f, 1.0f);