diff --git a/src/libthecore/fdwatch.cpp b/src/libthecore/fdwatch.cpp index d03fb49..fc37de6 100644 --- a/src/libthecore/fdwatch.cpp +++ b/src/libthecore/fdwatch.cpp @@ -47,6 +47,21 @@ static void win32_deinit() } #endif +#if defined(__linux__) && !defined(OS_WINDOWS) +#ifndef EPOLLRDHUP +#define EPOLLRDHUP 0 /* fallback when the headers lack EPOLLRDHUP; a zero mask makes the flag a no-op */ +#endif + +struct FdwatchSelectState +{ + int epoll_fd; /* descriptor returned by epoll_create1/epoll_create */ + epoll_event* epoll_events; /* result buffer handed to epoll_wait */ + int nepoll_events; /* number of valid entries in epoll_events after the last wait */ + void** fd_data; /* client_data per descriptor, indexed by fd */ + int* fd_rw; /* FDW_* interest bits per descriptor, indexed by fd */ +}; +#else + struct FdwatchSelectState { fd_set rfd_set; @@ -64,6 +79,8 @@ struct FdwatchSelectState #endif +#endif + struct fdwatch { EFdwatchBackend backend; @@ -78,7 +95,9 @@ struct fdwatch static EFdwatchBackend fdwatch_default_backend() { -#ifndef __USE_SELECT__ +#if defined(__linux__) && !defined(OS_WINDOWS) + return FDWATCH_BACKEND_EPOLL; +#elif !defined(__USE_SELECT__) return FDWATCH_BACKEND_KQUEUE; #else return FDWATCH_BACKEND_SELECT; @@ -326,6 +345,258 @@ static void* fdwatch_get_client_data_kqueue(LPFDWATCH fdw, unsigned int event_id #else +#if defined(__linux__) && !defined(OS_WINDOWS) + +static uint32_t fdwatch_epoll_events(int rw) /* translate FDW_READ/FDW_WRITE interest bits into an epoll event mask */ +{ + uint32_t events = 0; + + if (rw & FDW_READ) + events |= EPOLLIN | EPOLLRDHUP; + + if (rw & FDW_WRITE) + events |= EPOLLOUT; + + return events; +} + +static bool fdwatch_update_epoll_interest(LPFDWATCH fdw, socket_t fd, int old_rw, int new_rw) /* issue the epoll_ctl ADD/MOD/DEL that moves fd from interest old_rw to new_rw; returns false (after logging) on failure */ +{ + epoll_event event {}; + event.events = fdwatch_epoll_events(new_rw); + event.data.fd = fd; + + int op = EPOLL_CTL_MOD; + if (!(new_rw & (FDW_READ | FDW_WRITE))) /* checked first so that an empty-to-empty transition becomes a tolerated DEL instead of an ADD with an empty event mask */ + op = EPOLL_CTL_DEL; + else if (!(old_rw & (FDW_READ | FDW_WRITE))) + op = EPOLL_CTL_ADD; + + if (epoll_ctl(fdw->state.epoll_fd, op, fd, op == EPOLL_CTL_DEL ? 
nullptr : &event) == 0) + return true; + + if (op == EPOLL_CTL_DEL && (errno == ENOENT || errno == EBADF)) /* descriptor is not registered or already closed: nothing left to delete */ + return true; + + sys_err("epoll_ctl(%d, fd=%d) failed: %s", op, fd, strerror(errno)); + return false; +} + +static LPFDWATCH fdwatch_new_select(int nfiles) /* epoll-backed constructor; the event buffer and the per-fd tables are sized by nfiles */ +{ + LPFDWATCH fdw; + +#ifdef EPOLL_CLOEXEC + const int epoll_fd = epoll_create1(EPOLL_CLOEXEC); +#else + const int epoll_fd = epoll_create(nfiles); +#endif + + if (epoll_fd == -1) + { + sys_err("epoll_create failed: %s", strerror(errno)); + return NULL; + } + + CREATE(fdw, FDWATCH, 1); + fdw->backend = FDWATCH_BACKEND_EPOLL; + fdw->descriptor_limit = nfiles; + fdw->state.epoll_fd = epoll_fd; + fdw->state.nepoll_events = 0; + + CREATE(fdw->state.epoll_events, epoll_event, fdw->descriptor_limit); + CREATE(fdw->state.fd_rw, int, fdw->descriptor_limit); + CREATE(fdw->state.fd_data, void*, fdw->descriptor_limit); + + return fdw; +} + +static void fdwatch_delete_select(LPFDWATCH fdw) +{ + close(fdw->state.epoll_fd); + free(fdw->state.epoll_events); + free(fdw->state.fd_data); + free(fdw->state.fd_rw); + free(fdw); +} + +static int fdwatch_check_event_select(LPFDWATCH fdw, socket_t fd, unsigned int event_idx); + +static int fdwatch_check_fd_select(LPFDWATCH fdw, socket_t fd) +{ + for (int i = 0; i < fdw->state.nepoll_events; ++i) + { + if (fdw->state.epoll_events[i].data.fd != fd) + continue; + + if (fdwatch_check_event_select(fdw, fd, i) != 0) + return 1; + } + + return 0; +} + +static void fdwatch_clear_fd_select(LPFDWATCH fdw, socket_t fd) +{ + if (fd < 0 || fd >= fdw->descriptor_limit) + return; + + const int old_rw = fdw->state.fd_rw[fd]; + fdw->state.fd_rw[fd] = 0; + fdw->state.fd_data[fd] = NULL; + fdwatch_update_epoll_interest(fdw, fd, old_rw, 0); + + for (int i = 0; i < fdw->state.nepoll_events; ++i) + { + if (fdw->state.epoll_events[i].data.fd == fd) + { + fdw->state.epoll_events[i].data.fd = -1; + fdw->state.epoll_events[i].events = 0; + } + } +} + +static void fdwatch_add_fd_select(LPFDWATCH fdw, 
socket_t fd, void* client_data, int rw, int oneshot) +{ + if (fd < 0 || fd >= fdw->descriptor_limit) + { + sys_err("fdwatch_add_fd_epoll: descriptor %d exceeds backend limit %d", fd, fdw->descriptor_limit); + return; + } + + const int old_rw = fdw->state.fd_rw[fd]; + int new_rw = old_rw | rw; + + if (oneshot && (rw & FDW_WRITE)) + new_rw |= FDW_WRITE_ONESHOT; + + fdw->state.fd_data[fd] = client_data; + + if (fdwatch_update_epoll_interest(fdw, fd, old_rw, new_rw)) + fdw->state.fd_rw[fd] = new_rw; +} + +static void fdwatch_del_fd_select(LPFDWATCH fdw, socket_t fd) +{ + if (fd < 0 || fd >= fdw->descriptor_limit) + return; + + const int old_rw = fdw->state.fd_rw[fd]; + if (!(old_rw & (FDW_READ | FDW_WRITE | FDW_WRITE_ONESHOT)) && !fdw->state.fd_data[fd]) + return; + + fdw->state.fd_rw[fd] = 0; + fdw->state.fd_data[fd] = NULL; + fdwatch_update_epoll_interest(fdw, fd, old_rw, 0); + + for (int i = 0; i < fdw->state.nepoll_events; ++i) + { + if (fdw->state.epoll_events[i].data.fd == fd) + { + fdw->state.epoll_events[i].data.fd = -1; + fdw->state.epoll_events[i].events = 0; + } + } +} + +static int fdwatch_select(LPFDWATCH fdw, struct timeval* timeout) /* wait for events; a NULL timeout polls without blocking because timeout_ms stays 0 */ +{ + int timeout_ms = 0; + + if (timeout) + timeout_ms = static_cast<int>(timeout->tv_sec * 1000 + timeout->tv_usec / 1000); + + const int result = epoll_wait(fdw->state.epoll_fd, fdw->state.epoll_events, fdw->descriptor_limit, timeout_ms); + fdw->state.nepoll_events = result > 0 ? result : 0; /* stored before the error check so a failed or interrupted wait exposes no stale events */ + + if (result == -1) + { + if (errno == EINTR) + return 0; + + return -1; + } + + return result; +} + +static void* fdwatch_get_client_data_select(LPFDWATCH fdw, unsigned int event_idx) +{ + if (event_idx >= static_cast<unsigned int>(fdw->state.nepoll_events)) + return NULL; + + const int fd = fdw->state.epoll_events[event_idx].data.fd; + + if (fd < 0 || fd >= fdw->descriptor_limit) + return NULL; + + return fdw->state.fd_data[fd]; +} + +static int fdwatch_get_ident_select(LPFDWATCH fdw, unsigned int event_idx) +{ + if (event_idx >= static_cast<unsigned int>(fdw->state.nepoll_events)) 
+ return 0; + + return fdw->state.epoll_events[event_idx].data.fd; +} + +static void fdwatch_clear_event_select(LPFDWATCH fdw, socket_t fd, unsigned int event_idx) +{ + if (event_idx >= static_cast<unsigned int>(fdw->state.nepoll_events)) + return; + + if (fdw->state.epoll_events[event_idx].data.fd != fd) + return; + + fdw->state.epoll_events[event_idx].data.fd = -1; + fdw->state.epoll_events[event_idx].events = 0; +} + +static int fdwatch_check_event_select(LPFDWATCH fdw, socket_t fd, unsigned int event_idx) /* returns FDW_READ, FDW_WRITE, FDW_EOF or 0 for the event at event_idx; drops a one-shot write interest once it fires */ +{ + if (event_idx >= static_cast<unsigned int>(fdw->state.nepoll_events)) + return 0; + + const epoll_event& event = fdw->state.epoll_events[event_idx]; + if (event.data.fd != fd) + return 0; + + if (fd < 0 || fd >= fdw->descriptor_limit) + return 0; + + const int fd_rw = fdw->state.fd_rw[fd]; + + if ((fd_rw & FDW_READ) && (event.events & (EPOLLIN | EPOLLPRI))) + return FDW_READ; + + if ((fd_rw & FDW_WRITE) && (event.events & EPOLLOUT)) + { + if (fd_rw & FDW_WRITE_ONESHOT) + { + const int new_rw = fd_rw & ~(FDW_WRITE | FDW_WRITE_ONESHOT); + if (fdwatch_update_epoll_interest(fdw, fd, fd_rw, new_rw)) + fdw->state.fd_rw[fd] = new_rw; + } + + return FDW_WRITE; + } + + if (event.events & (EPOLLERR | EPOLLHUP | EPOLLRDHUP)) + return FDW_EOF; + + return 0; +} + +static int fdwatch_get_buffer_size_select(LPFDWATCH fdw, socket_t fd) +{ + (void)fdw; + (void)fd; + return INT_MAX; +} + +#else + static LPFDWATCH fdwatch_new_select(int nfiles) { LPFDWATCH fdw; @@ -617,6 +888,8 @@ static int fdwatch_get_buffer_size_select(LPFDWATCH fdw, socket_t fd) return INT_MAX; // XXX TODO } +#endif + #endif EFdwatchBackend fdwatch_get_backend(LPFDWATCH fdw) { @@ -634,6 +907,8 @@ const char* fdwatch_backend_name(EFdwatchBackend backend) return "kqueue"; case FDWATCH_BACKEND_SELECT: return "select"; + case FDWATCH_BACKEND_EPOLL: + return "epoll"; default: return "unknown"; } diff --git a/src/libthecore/fdwatch.h b/src/libthecore/fdwatch.h index 7075a78..943b8c6 100644 --- a/src/libthecore/fdwatch.h +++ 
b/src/libthecore/fdwatch.h @@ -16,6 +16,7 @@ enum EFdwatchBackend { FDWATCH_BACKEND_KQUEUE = 0, FDWATCH_BACKEND_SELECT = 1, + FDWATCH_BACKEND_EPOLL = 2, }; LPFDWATCH fdwatch_new(int nfiles); diff --git a/src/libthecore/stdafx.h b/src/libthecore/stdafx.h index 8a2def2..d0da494 100644 --- a/src/libthecore/stdafx.h +++ b/src/libthecore/stdafx.h @@ -102,6 +102,8 @@ inline double rint(double x) #ifdef OS_FREEBSD #include <sys/event.h> +#elif defined(__linux__) +#include <sys/epoll.h> #endif #endif diff --git a/tests/smoke_auth.cpp b/tests/smoke_auth.cpp index 7d56048..829d2eb 100644 --- a/tests/smoke_auth.cpp +++ b/tests/smoke_auth.cpp @@ -297,7 +297,11 @@ void TestFdwatchBackendMetadata() LPFDWATCH fdw = fdwatch_new(4096); Expect(fdw != nullptr, "fdwatch_new for backend metadata failed"); -#ifdef __USE_SELECT__ +#ifdef __linux__ + Expect(fdwatch_get_backend(fdw) == FDWATCH_BACKEND_EPOLL, "Expected epoll backend"); + Expect(std::strcmp(fdwatch_backend_name(fdwatch_get_backend(fdw)), "epoll") == 0, "Unexpected epoll backend name"); + Expect(fdwatch_get_descriptor_limit(fdw) == 4096, "Unexpected epoll descriptor limit"); +#elif defined(__USE_SELECT__) Expect(fdwatch_get_backend(fdw) == FDWATCH_BACKEND_SELECT, "Expected select backend"); Expect(std::strcmp(fdwatch_backend_name(fdwatch_get_backend(fdw)), "select") == 0, "Unexpected select backend name"); Expect(fdwatch_get_descriptor_limit(fdw) == std::min(4096, static_cast<int>(FD_SETSIZE)), "Unexpected select descriptor limit");