diff --git a/windows/hid.c b/windows/hid.c index 35c2de04..782b5dbf 100644 --- a/windows/hid.c +++ b/windows/hid.c @@ -78,6 +78,16 @@ typedef LONG NTSTATUS; #define MAX_STRING_WCHARS_USB 126 +#if defined(__GNUC__) +#define thread_local __thread +#elif __STDC_VERSION__ >= 201112L +#define thread_local _Thread_local +#elif defined(_MSC_VER) +#define thread_local __declspec(thread) +#else +#error Cannot define thread_local +#endif + static struct hid_api_version api_version = { .major = HID_API_VERSION_MAJOR, .minor = HID_API_VERSION_MINOR, @@ -180,6 +190,37 @@ static int lookup_functions() #endif /* HIDAPI_USE_DDK */ +typedef void (*tls_destructor)(void **data, hid_device *dev, BOOLEAN all_devices); + +struct tls_allocation { + void *data; + DWORD thread_id; + tls_destructor destructor; + + struct tls_allocation *next; +}; + +struct device_error { + HANDLE device_handle; + wchar_t *last_error_str; + + struct device_error *next; +}; + +struct tls_context { + struct tls_allocation *allocated; + CRITICAL_SECTION critical_section; + BOOLEAN critical_section_ready; +}; + +static struct tls_context tls_context = { + .allocated = NULL, + .critical_section_ready = FALSE +}; + +// Use a NULL device handle for the global error +static thread_local struct device_error *device_error = NULL; + struct hid_device_ { HANDLE device_handle; BOOL blocking; @@ -188,7 +229,6 @@ struct hid_device_ { size_t input_report_length; USHORT feature_report_length; unsigned char *feature_buf; - wchar_t *last_error_str; BOOL read_pending; char *read_buf; OVERLAPPED ol; @@ -211,7 +251,6 @@ static hid_device *new_hid_device() dev->input_report_length = 0; dev->feature_report_length = 0; dev->feature_buf = NULL; - dev->last_error_str = NULL; dev->read_pending = FALSE; dev->read_buf = NULL; memset(&dev->ol, 0, sizeof(dev->ol)); @@ -223,13 +262,151 @@ static hid_device *new_hid_device() return dev; } +static void tls_init_context() +{ + if (!tls_context.critical_section_ready) { + InitializeCriticalSection(&tls_context.critical_section); + tls_context.critical_section_ready = TRUE; + } +} + +static void tls_exit_context() +{ + if (tls_context.critical_section_ready) { + DeleteCriticalSection(&tls_context.critical_section); + tls_context.critical_section_ready = FALSE; + } +} + +static void tls_register(void* data, tls_destructor destructor) +{ + if (!tls_context.critical_section_ready) { + return; + } + + DWORD thread_id = GetCurrentThreadId(); + + EnterCriticalSection(&tls_context.critical_section); + + struct tls_allocation *current = tls_context.allocated; + struct tls_allocation *prev = NULL; + + while (current) { + prev = current; + current = current->next; + } + + struct tls_allocation *tls = (struct tls_allocation*) malloc(sizeof(struct tls_allocation)); + tls->data = data; + tls->thread_id = thread_id; + tls->destructor = destructor; + tls->next = NULL; + + if (prev) { + prev->next = tls; + } + else { + tls_context.allocated = tls; + } + + LeaveCriticalSection(&tls_context.critical_section); +} + +static void tls_free(DWORD thread_id, hid_device *dev, BOOLEAN all_devices) +{ + if (!tls_context.critical_section_ready) { + return; + } + + EnterCriticalSection(&tls_context.critical_section); + + struct tls_allocation *current = tls_context.allocated; + struct tls_allocation *prev = NULL; + + while (current) { + if (thread_id != 0 && thread_id != current->thread_id) { + prev = current; + current = current->next; + continue; + } + + current->destructor(¤t->data, dev, all_devices); + + if (current->data == NULL) { + if (prev) { + prev->next = current->next; + } + else { + tls_context.allocated = current->next; + } + + struct tls_allocation *current_tmp = current; + current = current->next; + free(current_tmp); + } + else { + prev = current; + current = current->next; + } + } + + LeaveCriticalSection(&tls_context.critical_section); +} + +static void tls_free_all_threads(hid_device *dev, BOOLEAN all_devices) +{ + tls_free(0, dev, all_devices); +} + +static void free_error_buffer(struct device_error **error, hid_device *dev, BOOLEAN all_devices) +{ + if (error == NULL) { + return; + } + + struct device_error *current = *error; + + if (all_devices) { + while (current) { + struct device_error *current_tmp = current; + current = current->next; + free(current_tmp->last_error_str); + free(current_tmp); + } + + *error = NULL; + } + else + { + struct device_error *prev = NULL; + + while (current) { + if ((dev == NULL && current->device_handle == NULL) || + (dev != NULL && dev->device_handle == current->device_handle)) { + if (prev) { + prev->next = current->next; + } + else { + *error = current->next; + } + + free(current->last_error_str); + free(current); + break; + } + + prev = current; + current = current->next; + } + } +} + static void free_hid_device(hid_device *dev) { CloseHandle(dev->ol.hEvent); CloseHandle(dev->write_ol.hEvent); CloseHandle(dev->device_handle); - free(dev->last_error_str); - dev->last_error_str = NULL; + tls_free_all_threads(dev, FALSE); free(dev->write_buf); free(dev->feature_buf); free(dev->read_buf); @@ -316,26 +493,75 @@ static void register_string_error_to_buffer(wchar_t **error_buffer, const WCHAR # pragma GCC diagnostic pop #endif +static wchar_t** get_error_buffer(hid_device *dev) +{ + struct device_error *current = device_error; + struct device_error *prev = NULL; + + while (current) { + if ((dev == NULL && current->device_handle == NULL) || + (dev != NULL && dev->device_handle == current->device_handle)) { + return ¤t->last_error_str; + } + + prev = current; + current = current->next; + } + + struct device_error *error = (struct device_error*) malloc(sizeof(struct device_error)); + error->device_handle = dev != NULL ? dev->device_handle : NULL; + error->last_error_str = NULL; + error->next = NULL; + + if (prev) { + prev->next = error; + } + else { + device_error = error; + tls_register(device_error, (tls_destructor)&free_error_buffer); + } + + return &error->last_error_str; +} + +static wchar_t* get_error_str(hid_device *dev) +{ + struct device_error *current = device_error; + + while (current) { + if ((dev == NULL && current->device_handle == NULL) || + (dev != NULL && dev->device_handle == current->device_handle)) { + return current->last_error_str; + } + + current = current->next; + } + + return NULL; +} + static void register_winapi_error(hid_device *dev, const WCHAR *op) { - register_winapi_error_to_buffer(&dev->last_error_str, op); + wchar_t **error_buffer = get_error_buffer(dev); + register_winapi_error_to_buffer(error_buffer, op); } static void register_string_error(hid_device *dev, const WCHAR *string_error) { - register_string_error_to_buffer(&dev->last_error_str, string_error); + wchar_t **error_buffer = get_error_buffer(dev); + register_string_error_to_buffer(error_buffer, string_error); } -static wchar_t *last_global_error_str = NULL; - static void register_global_winapi_error(const WCHAR *op) { - register_winapi_error_to_buffer(&last_global_error_str, op); + wchar_t **error_buffer = get_error_buffer(NULL); + register_winapi_error_to_buffer(error_buffer, op); } static void register_global_error(const WCHAR *string_error) { - register_string_error_to_buffer(&last_global_error_str, string_error); + wchar_t **error_buffer = get_error_buffer(NULL); + register_string_error_to_buffer(error_buffer, string_error); } static HANDLE open_device(const wchar_t *path, BOOL open_rw) @@ -365,6 +591,29 @@ HID_API_EXPORT const char* HID_API_CALL hid_version_str(void) return HID_API_VERSION_STR; } +BOOL WINAPI DllMain(HINSTANCE instance, DWORD reason, LPVOID reserved) +{ + (void)instance; + (void)reserved; + + switch (reason) { + case DLL_PROCESS_ATTACH: + tls_init_context(); + break; + + case DLL_PROCESS_DETACH: + tls_exit_context(); + break; + + case DLL_THREAD_DETACH: { + DWORD thread_id = GetCurrentThreadId(); + tls_free(thread_id, NULL, TRUE); + break; + } + } + return TRUE; +} + int HID_API_EXPORT hid_init(void) { register_global_error(NULL); @@ -386,7 +635,7 @@ int HID_API_EXPORT hid_exit(void) free_library_handles(); hidapi_initialized = FALSE; #endif - register_global_error(NULL); + tls_free_all_threads(NULL, TRUE); return 0; } @@ -1529,15 +1778,10 @@ int HID_API_EXPORT_CALL hid_get_report_descriptor(hid_device *dev, unsigned char HID_API_EXPORT const wchar_t * HID_API_CALL hid_error(hid_device *dev) { - if (dev) { - if (dev->last_error_str == NULL) - return L"Success"; - return (wchar_t*)dev->last_error_str; - } - - if (last_global_error_str == NULL) + wchar_t *error_str = get_error_str(dev); + if (error_str == NULL) return L"Success"; - return last_global_error_str; + return error_str; } #ifndef hidapi_winapi_EXPORTS