From 521529d24d3815777fb62ed34328daf1c7ca75ff Mon Sep 17 00:00:00 2001
From: yhirose <yhirose@users.noreply.github.com>
Date: Tue, 6 Aug 2024 13:43:00 -0400
Subject: [PATCH] Fix #1481 (with content provider) (#1527)

* Fix #1481 (with content provider)

* Improve shutdown performance

* Make shutdown action more stable

* Move some tests up

* Simplified

* Simplified
---
 httplib.h    |  25 +++-
 test/test.cc | 374 +++++++++++++++++++++++++++++----------------------
 2 files changed, 235 insertions(+), 164 deletions(-)

diff --git a/httplib.h b/httplib.h
index 2aefc4b..d6d6541 100644
--- a/httplib.h
+++ b/httplib.h
@@ -8541,13 +8541,29 @@ inline SSL *ssl_new(socket_t sock, SSL_CTX *ctx, std::mutex &ctx_mutex,
   return ssl;
 }
 
-inline void ssl_delete(std::mutex &ctx_mutex, SSL *ssl,
+inline void ssl_delete(std::mutex &ctx_mutex, SSL *ssl, socket_t sock,
                        bool shutdown_gracefully) {
   // sometimes we may want to skip this to try to avoid SIGPIPE if we know
   // the remote has closed the network connection
   // Note that it is not always possible to avoid SIGPIPE, this is merely a
   // best-efforts.
-  if (shutdown_gracefully) { SSL_shutdown(ssl); }
+  if (shutdown_gracefully) {
+#ifdef _WIN32
+    SSL_shutdown(ssl);
+#else
+    timeval tv;
+    tv.tv_sec = 1;
+    tv.tv_usec = 0;
+    setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO,
+               reinterpret_cast<const void *>(&tv), sizeof(tv));
+
+    auto ret = SSL_shutdown(ssl);
+    while (ret == 0) {
+      std::this_thread::sleep_for(std::chrono::milliseconds(100));
+      ret = SSL_shutdown(ssl);
+    }
+#endif
+  }
 
   std::lock_guard<std::mutex> guard(ctx_mutex);
   SSL_free(ssl);
@@ -8826,7 +8842,7 @@ inline bool SSLServer::process_and_close_socket(socket_t sock) {
     // Shutdown gracefully if the result seemed successful, non-gracefully if
     // the connection appeared to be closed.
     const bool shutdown_gracefully = ret;
-    detail::ssl_delete(ctx_mutex_, ssl, shutdown_gracefully);
+    detail::ssl_delete(ctx_mutex_, ssl, sock, shutdown_gracefully);
   }
 
   detail::shutdown_socket(sock);
@@ -9109,7 +9125,8 @@ inline void SSLClient::shutdown_ssl_impl(Socket &socket,
     return;
   }
   if (socket.ssl) {
-    detail::ssl_delete(ctx_mutex_, socket.ssl, shutdown_gracefully);
+    detail::ssl_delete(ctx_mutex_, socket.ssl, socket.sock,
+                       shutdown_gracefully);
     socket.ssl = nullptr;
   }
   assert(socket.ssl == nullptr);
diff --git a/test/test.cc b/test/test.cc
index df69d4a..d6610a7 100644
--- a/test/test.cc
+++ b/test/test.cc
@@ -54,6 +54,166 @@ MultipartFormData &get_file_value(MultipartFormDataItems &files,
 #endif
 }
 
+#ifndef _WIN32
+class UnixSocketTest : public ::testing::Test {
+protected:
+  void TearDown() override { std::remove(pathname_.c_str()); }
+
+  void client_GET(const std::string &addr) {
+    httplib::Client cli{addr};
+    cli.set_address_family(AF_UNIX);
+    ASSERT_TRUE(cli.is_valid());
+
+    const auto &result = cli.Get(pattern_);
+    ASSERT_TRUE(result) << "error: " << result.error();
+
+    const auto &resp = result.value();
+    EXPECT_EQ(resp.status, StatusCode::OK_200);
+    EXPECT_EQ(resp.body, content_);
+  }
+
+  const std::string pathname_{"./httplib-server.sock"};
+  const std::string pattern_{"/hi"};
+  const std::string content_{"Hello World!"};
+};
+
+TEST_F(UnixSocketTest, pathname) {
+  httplib::Server svr;
+  svr.Get(pattern_, [&](const httplib::Request &, httplib::Response &res) {
+    res.set_content(content_, "text/plain");
+  });
+
+  std::thread t{[&] {
+    ASSERT_TRUE(svr.set_address_family(AF_UNIX).listen(pathname_, 80));
+  }};
+  auto se = detail::scope_exit([&] {
+    svr.stop();
+    t.join();
+    ASSERT_FALSE(svr.is_running());
+  });
+
+  svr.wait_until_ready();
+  ASSERT_TRUE(svr.is_running());
+
+  client_GET(pathname_);
+}
+
+#if defined(__linux__) ||                                                      \
+    /* __APPLE__ */ (defined(SOL_LOCAL) && defined(SO_PEERPID))
+TEST_F(UnixSocketTest, PeerPid) {
+  httplib::Server svr;
+  std::string remote_port_val;
+  svr.Get(pattern_, [&](const httplib::Request &req, httplib::Response &res) {
+    res.set_content(content_, "text/plain");
+    remote_port_val = req.get_header_value("REMOTE_PORT");
+  });
+
+  std::thread t{[&] {
+    ASSERT_TRUE(svr.set_address_family(AF_UNIX).listen(pathname_, 80));
+  }};
+  auto se = detail::scope_exit([&] {
+    svr.stop();
+    t.join();
+    ASSERT_FALSE(svr.is_running());
+  });
+
+  svr.wait_until_ready();
+  ASSERT_TRUE(svr.is_running());
+
+  client_GET(pathname_);
+  EXPECT_EQ(std::to_string(getpid()), remote_port_val);
+}
+#endif
+
+#ifdef __linux__
+TEST_F(UnixSocketTest, abstract) {
+  constexpr char svr_path[]{"\x00httplib-server.sock"};
+  const std::string abstract_addr{svr_path, sizeof(svr_path) - 1};
+
+  httplib::Server svr;
+  svr.Get(pattern_, [&](const httplib::Request &, httplib::Response &res) {
+    res.set_content(content_, "text/plain");
+  });
+
+  std::thread t{[&] {
+    ASSERT_TRUE(svr.set_address_family(AF_UNIX).listen(abstract_addr, 80));
+  }};
+  auto se = detail::scope_exit([&] {
+    svr.stop();
+    t.join();
+    ASSERT_FALSE(svr.is_running());
+  });
+
+  svr.wait_until_ready();
+  ASSERT_TRUE(svr.is_running());
+
+  client_GET(abstract_addr);
+}
+#endif
+
+TEST(SocketStream, is_writable_UNIX) {
+  int fds[2];
+  ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_STREAM, 0, fds));
+
+  const auto asSocketStream = [&](socket_t fd,
+                                  std::function<bool(Stream &)> func) {
+    return detail::process_client_socket(fd, 0, 0, 0, 0, func);
+  };
+  asSocketStream(fds[0], [&](Stream &s0) {
+    EXPECT_EQ(s0.socket(), fds[0]);
+    EXPECT_TRUE(s0.is_writable());
+
+    EXPECT_EQ(0, close(fds[1]));
+    EXPECT_FALSE(s0.is_writable());
+
+    return true;
+  });
+  EXPECT_EQ(0, close(fds[0]));
+}
+
+TEST(SocketStream, is_writable_INET) {
+  sockaddr_in addr;
+  memset(&addr, 0, sizeof(addr));
+  addr.sin_family = AF_INET;
+  addr.sin_port = htons(PORT + 1);
+  addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+
+  int disconnected_svr_sock = -1;
+  std::thread svr{[&] {
+    const int s = socket(AF_INET, SOCK_STREAM, 0);
+    ASSERT_LE(0, s);
+    ASSERT_EQ(0, ::bind(s, reinterpret_cast<sockaddr *>(&addr), sizeof(addr)));
+    ASSERT_EQ(0, listen(s, 1));
+    ASSERT_LE(0, disconnected_svr_sock = accept(s, nullptr, nullptr));
+    ASSERT_EQ(0, close(s));
+  }};
+  std::this_thread::sleep_for(std::chrono::milliseconds(100));
+
+  std::thread cli{[&] {
+    const int s = socket(AF_INET, SOCK_STREAM, 0);
+    ASSERT_LE(0, s);
+    ASSERT_EQ(0, connect(s, reinterpret_cast<sockaddr *>(&addr), sizeof(addr)));
+    ASSERT_EQ(0, close(s));
+  }};
+  cli.join();
+  svr.join();
+  ASSERT_NE(disconnected_svr_sock, -1);
+
+  const auto asSocketStream = [&](socket_t fd,
+                                  std::function<bool(Stream &)> func) {
+    return detail::process_client_socket(fd, 0, 0, 0, 0, func);
+  };
+  asSocketStream(disconnected_svr_sock, [&](Stream &ss) {
+    EXPECT_EQ(ss.socket(), disconnected_svr_sock);
+    EXPECT_FALSE(ss.is_writable());
+
+    return true;
+  });
+
+  ASSERT_EQ(0, close(disconnected_svr_sock));
+}
+#endif // #ifndef _WIN32
+
 TEST(ClientTest, MoveConstructible) {
   EXPECT_FALSE(std::is_copy_constructible<Client>::value);
   EXPECT_TRUE(std::is_nothrow_move_constructible<Client>::value);
@@ -4996,6 +5156,60 @@ TEST(KeepAliveTest, SSLClientReconnection) {
   ASSERT_TRUE(result);
   EXPECT_EQ(StatusCode::OK_200, result->status);
 }
+
+TEST(KeepAliveTest, SSLClientReconnectionPost) {
+  SSLServer svr(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE);
+  ASSERT_TRUE(svr.is_valid());
+  svr.set_keep_alive_timeout(1);
+  std::string content = "reconnect";
+
+  svr.Post("/hi", [](const httplib::Request &, httplib::Response &res) {
+    res.set_content("Hello World!", "text/plain");
+  });
+
+  auto f = std::async(std::launch::async, [&svr] { svr.listen(HOST, PORT); });
+  std::this_thread::sleep_for(std::chrono::milliseconds(200));
+
+  SSLClient cli(HOST, PORT);
+  cli.enable_server_certificate_verification(false);
+  cli.set_keep_alive(true);
+
+  auto result = cli.Post(
+      "/hi", content.size(),
+      [&content](size_t offset, size_t length, DataSink &sink) {
+        sink.write(content.c_str(), content.size());
+        return true;
+      },
+      "text/plain");
+  ASSERT_TRUE(result);
+  EXPECT_EQ(200, result->status);
+
+  std::this_thread::sleep_for(std::chrono::seconds(2));
+
+  // Recoonect
+  result = cli.Post(
+      "/hi", content.size(),
+      [&content](size_t offset, size_t length, DataSink &sink) {
+        sink.write(content.c_str(), content.size());
+        return true;
+      },
+      "text/plain");
+  ASSERT_TRUE(result);
+  EXPECT_EQ(200, result->status);
+
+  result = cli.Post(
+      "/hi", content.size(),
+      [&content](size_t offset, size_t length, DataSink &sink) {
+        sink.write(content.c_str(), content.size());
+        return true;
+      },
+      "text/plain");
+  ASSERT_TRUE(result);
+  EXPECT_EQ(200, result->status);
+
+  svr.stop();
+  f.wait();
+}
 #endif
 
 TEST(ClientProblemDetectionTest, ContentProvider) {
@@ -6970,166 +7184,6 @@ TEST(MultipartFormDataTest, ContentLength) {
 
 #endif
 
-#ifndef _WIN32
-class UnixSocketTest : public ::testing::Test {
-protected:
-  void TearDown() override { std::remove(pathname_.c_str()); }
-
-  void client_GET(const std::string &addr) {
-    httplib::Client cli{addr};
-    cli.set_address_family(AF_UNIX);
-    ASSERT_TRUE(cli.is_valid());
-
-    const auto &result = cli.Get(pattern_);
-    ASSERT_TRUE(result) << "error: " << result.error();
-
-    const auto &resp = result.value();
-    EXPECT_EQ(resp.status, StatusCode::OK_200);
-    EXPECT_EQ(resp.body, content_);
-  }
-
-  const std::string pathname_{"./httplib-server.sock"};
-  const std::string pattern_{"/hi"};
-  const std::string content_{"Hello World!"};
-};
-
-TEST_F(UnixSocketTest, pathname) {
-  httplib::Server svr;
-  svr.Get(pattern_, [&](const httplib::Request &, httplib::Response &res) {
-    res.set_content(content_, "text/plain");
-  });
-
-  std::thread t{[&] {
-    ASSERT_TRUE(svr.set_address_family(AF_UNIX).listen(pathname_, 80));
-  }};
-  auto se = detail::scope_exit([&] {
-    svr.stop();
-    t.join();
-    ASSERT_FALSE(svr.is_running());
-  });
-
-  svr.wait_until_ready();
-  ASSERT_TRUE(svr.is_running());
-
-  client_GET(pathname_);
-}
-
-#if defined(__linux__) ||                                                      \
-    /* __APPLE__ */ (defined(SOL_LOCAL) && defined(SO_PEERPID))
-TEST_F(UnixSocketTest, PeerPid) {
-  httplib::Server svr;
-  std::string remote_port_val;
-  svr.Get(pattern_, [&](const httplib::Request &req, httplib::Response &res) {
-    res.set_content(content_, "text/plain");
-    remote_port_val = req.get_header_value("REMOTE_PORT");
-  });
-
-  std::thread t{[&] {
-    ASSERT_TRUE(svr.set_address_family(AF_UNIX).listen(pathname_, 80));
-  }};
-  auto se = detail::scope_exit([&] {
-    svr.stop();
-    t.join();
-    ASSERT_FALSE(svr.is_running());
-  });
-
-  svr.wait_until_ready();
-  ASSERT_TRUE(svr.is_running());
-
-  client_GET(pathname_);
-  EXPECT_EQ(std::to_string(getpid()), remote_port_val);
-}
-#endif
-
-#ifdef __linux__
-TEST_F(UnixSocketTest, abstract) {
-  constexpr char svr_path[]{"\x00httplib-server.sock"};
-  const std::string abstract_addr{svr_path, sizeof(svr_path) - 1};
-
-  httplib::Server svr;
-  svr.Get(pattern_, [&](const httplib::Request &, httplib::Response &res) {
-    res.set_content(content_, "text/plain");
-  });
-
-  std::thread t{[&] {
-    ASSERT_TRUE(svr.set_address_family(AF_UNIX).listen(abstract_addr, 80));
-  }};
-  auto se = detail::scope_exit([&] {
-    svr.stop();
-    t.join();
-    ASSERT_FALSE(svr.is_running());
-  });
-
-  svr.wait_until_ready();
-  ASSERT_TRUE(svr.is_running());
-
-  client_GET(abstract_addr);
-}
-#endif
-
-TEST(SocketStream, is_writable_UNIX) {
-  int fds[2];
-  ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_STREAM, 0, fds));
-
-  const auto asSocketStream = [&](socket_t fd,
-                                  std::function<bool(Stream &)> func) {
-    return detail::process_client_socket(fd, 0, 0, 0, 0, func);
-  };
-  asSocketStream(fds[0], [&](Stream &s0) {
-    EXPECT_EQ(s0.socket(), fds[0]);
-    EXPECT_TRUE(s0.is_writable());
-
-    EXPECT_EQ(0, close(fds[1]));
-    EXPECT_FALSE(s0.is_writable());
-
-    return true;
-  });
-  EXPECT_EQ(0, close(fds[0]));
-}
-
-TEST(SocketStream, is_writable_INET) {
-  sockaddr_in addr;
-  memset(&addr, 0, sizeof(addr));
-  addr.sin_family = AF_INET;
-  addr.sin_port = htons(PORT + 1);
-  addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
-
-  int disconnected_svr_sock = -1;
-  std::thread svr{[&] {
-    const int s = socket(AF_INET, SOCK_STREAM, 0);
-    ASSERT_LE(0, s);
-    ASSERT_EQ(0, ::bind(s, reinterpret_cast<sockaddr *>(&addr), sizeof(addr)));
-    ASSERT_EQ(0, listen(s, 1));
-    ASSERT_LE(0, disconnected_svr_sock = accept(s, nullptr, nullptr));
-    ASSERT_EQ(0, close(s));
-  }};
-  std::this_thread::sleep_for(std::chrono::milliseconds(100));
-
-  std::thread cli{[&] {
-    const int s = socket(AF_INET, SOCK_STREAM, 0);
-    ASSERT_LE(0, s);
-    ASSERT_EQ(0, connect(s, reinterpret_cast<sockaddr *>(&addr), sizeof(addr)));
-    ASSERT_EQ(0, close(s));
-  }};
-  cli.join();
-  svr.join();
-  ASSERT_NE(disconnected_svr_sock, -1);
-
-  const auto asSocketStream = [&](socket_t fd,
-                                  std::function<bool(Stream &)> func) {
-    return detail::process_client_socket(fd, 0, 0, 0, 0, func);
-  };
-  asSocketStream(disconnected_svr_sock, [&](Stream &ss) {
-    EXPECT_EQ(ss.socket(), disconnected_svr_sock);
-    EXPECT_FALSE(ss.is_writable());
-
-    return true;
-  });
-
-  ASSERT_EQ(0, close(disconnected_svr_sock));
-}
-#endif // #ifndef _WIN32
-
 TEST(TaskQueueTest, IncreaseAtomicInteger) {
   static constexpr unsigned int number_of_tasks{1000000};
   std::atomic_uint count{0};