From e0ebc431dc78e3eee04400c0619162f94f6476ab Mon Sep 17 00:00:00 2001
From: yhirose <yuji.hirose.bug@gmail.com>
Date: Fri, 11 Oct 2024 13:29:55 -0400
Subject: [PATCH] Fix #1959

---
 httplib.h    | 42 ++++++++++++++++++++++++++++++++++++++++--
 test/test.cc | 35 +++++++++++++++++++++++++++++++++++
 2 files changed, 75 insertions(+), 2 deletions(-)

diff --git a/httplib.h b/httplib.h
index 121ce34..99112b9 100644
--- a/httplib.h
+++ b/httplib.h
@@ -18,6 +18,10 @@
 #define CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND 5
 #endif
 
+#ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_CHECK_INTERVAL_USECOND
+#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_CHECK_INTERVAL_USECOND 10000
+#endif
+
 #ifndef CPPHTTPLIB_KEEPALIVE_MAX_COUNT
 #define CPPHTTPLIB_KEEPALIVE_MAX_COUNT 100
 #endif
@@ -3251,6 +3255,41 @@ private:
 };
 #endif
 
+inline bool keep_alive(const std::atomic<socket_t> &svr_sock, socket_t sock,
+                       time_t keep_alive_timeout_sec) {
+  using namespace std::chrono;
+
+  const auto interval_usec =
+      CPPHTTPLIB_KEEPALIVE_TIMEOUT_CHECK_INTERVAL_USECOND;
+
+  // Avoid expensive `steady_clock::now()` call for the first time
+  if (select_read(sock, 0, interval_usec) > 0) { return true; }
+
+  const auto start = steady_clock::now() - microseconds{interval_usec};
+  const auto timeout = seconds{keep_alive_timeout_sec};
+
+  while (true) {
+    if (svr_sock == INVALID_SOCKET) {
+      break; // Server socket is closed
+    }
+
+    auto val = select_read(sock, 0, interval_usec);
+    if (val < 0) {
+      break; // Ssocket error
+    } else if (val == 0) {
+      if (steady_clock::now() - start > timeout) {
+        break; // Timeout
+      }
+    } else {
+      return true; // Ready for read
+    }
+
+    std::this_thread::sleep_for(microseconds{interval_usec});
+  }
+
+  return false;
+}
+
 template <typename T>
 inline bool
 process_server_socket_core(const std::atomic<socket_t> &svr_sock, socket_t sock,
@@ -3259,8 +3298,7 @@ process_server_socket_core(const std::atomic<socket_t> &svr_sock, socket_t sock,
   assert(keep_alive_max_count > 0);
   auto ret = false;
   auto count = keep_alive_max_count;
-  while (svr_sock != INVALID_SOCKET && count > 0 &&
-         select_read(sock, keep_alive_timeout_sec, 0) > 0) {
+  while (count > 0 && keep_alive(svr_sock, sock, keep_alive_timeout_sec)) {
     auto close_connection = count == 1;
     auto connection_closed = false;
     ret = callback(close_connection, connection_closed);
diff --git a/test/test.cc b/test/test.cc
index 76c6f60..0cd450e 100644
--- a/test/test.cc
+++ b/test/test.cc
@@ -5268,6 +5268,41 @@ TEST(KeepAliveTest, Issue1041) {
   EXPECT_EQ(StatusCode::OK_200, result->status);
 }
 
+TEST(KeepAliveTest, Issue1959) {
+  Server svr;
+  svr.set_keep_alive_timeout(5);
+
+  svr.Get("/a", [&](const Request & /*req*/, Response &res) {
+    res.set_content("a", "text/plain");
+  });
+
+  auto listen_thread = std::thread([&svr]() { svr.listen("localhost", PORT); });
+  auto se = detail::scope_exit([&] {
+    if (!svr.is_running()) return;
+    svr.stop();
+    listen_thread.join();
+    ASSERT_FALSE(svr.is_running());
+  });
+
+  svr.wait_until_ready();
+
+  Client cli("localhost", PORT);
+  cli.set_keep_alive(true);
+
+  using namespace std::chrono;
+  auto start = steady_clock::now();
+
+  cli.Get("/a");
+
+  svr.stop();
+  listen_thread.join();
+
+  auto end = steady_clock::now();
+  auto elapsed = duration_cast<milliseconds>(end - start).count();
+
+  EXPECT_LT(elapsed, 5000);
+}
+
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
 TEST(KeepAliveTest, SSLClientReconnection) {
   SSLServer svr(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE);