Compare commits

...

2 Commits

Author SHA1 Message Date
Chenguang Li 3479efd112 CANN: Improve device ID handling and aclnnArange checks (#16752)
* cann: improve device ID handling and aclnnArange checks

- Stop relying on CANN's internal device ID retrieval; use a global variable instead.
- Enforce stricter dimension validation in aclnnArange for better compatibility across CANN versions.

* cann: use thread local var
2025-10-28 10:54:53 +08:00
Aman Gupta 463bbf20bf CUDA: add unused vars to mmvf and mmvq (#16807) 2025-10-28 10:31:21 +08:00
4 changed files with 26 additions and 7 deletions
+2 -2
View File
@@ -2234,7 +2234,7 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
ACL_MEM_MALLOC_HUGE_FIRST));
acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
theta_scale_ne, theta_scale_nb, 1);
float start = 0;
float step = 1;
@@ -2251,7 +2251,7 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
void * yarn_ramp_buffer = yarn_ramp_allocator.get();
acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne,
theta_scale_nb, GGML_MAX_DIMS);
theta_scale_nb, 1);
float zero_value = 0, one_value = 1;
float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
aclScalar * low = aclCreateScalar(&corr_dims[0], aclDataType::ACL_FLOAT);
+16 -5
View File
@@ -67,19 +67,30 @@
GGML_ABORT("CANN error");
}
// Thread-local variable to record the current device of this thread.
thread_local int g_current_cann_device = -1;
/**
* @brief Sets the device to be used by CANN.
* @brief Set the CANN device to be used.
*
* @param device The device ID to set.
* @param device The target device ID to set.
*/
void ggml_cann_set_device(const int32_t device) {
int current_device = -1;
aclrtGetDevice(&current_device);
// int current_device = -1;
// Note: In some CANN versions, if no device has been set yet,
// aclrtGetDevice(&current_device) may return 0 by default.
// aclrtGetDevice(&current_device);
if (device == current_device) {
// If the current device is already the target one, no need to switch.
if (device == g_current_cann_device) {
return;
}
// Switch to the new device.
ACL_CHECK(aclrtSetDevice(device));
// Update the global device record.
g_current_cann_device = device;
}
/**
+4
View File
@@ -343,6 +343,10 @@ static __global__ void mul_mat_vec_f(
}
dst[tid*stride_col_dst + row] = value;
if constexpr (!has_fusion) {
GGML_UNUSED_VARS(use_gate, use_bias, use_gate_bias, glu_op, gate_x, x_bias, gate_bias, sumf_gate);
}
}
template<typename T, typename type_acc, int ncols_dst, int block_size>
+4
View File
@@ -310,6 +310,10 @@ static __global__ void mul_mat_vec_q(
dst[j*stride_col_dst + threadIdx.x] = result;
}
}
if constexpr (!has_fusion) {
GGML_UNUSED_VARS(use_gate, use_bias, use_gate_bias, active_glu, gate_bias, x_bias, tmp_gate);
}
}
static std::pair<dim3, dim3> calc_launch_params(