From 7196ac8a07100da06cfe278719a5b96a1ddfee1b Mon Sep 17 00:00:00 2001
From: "Sung, Po Han" <bernies@synology.com>
Date: Wed, 4 Sep 2024 17:38:05 +0800
Subject: [PATCH 1/3] Fix incorrect handling of Expect: 100-continue

Fix #1808
---
 httplib.h           |   4 +-
 test/CMakeLists.txt |   4 +-
 test/test.cc        | 100 ++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 106 insertions(+), 2 deletions(-)

diff --git a/httplib.h b/httplib.h
index b7be298..7c7bf04 100644
--- a/httplib.h
+++ b/httplib.h
@@ -6956,7 +6956,9 @@ Server::process_request(Stream &strm, bool close_connection,
       strm.write_format("HTTP/1.1 %d %s\r\n\r\n", status,
                         status_message(status));
       break;
-    default: return write_response(strm, close_connection, req, res);
+    default:
+      connection_closed = true;
+      return write_response(strm, true, req, res);
     }
   }
 
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index d982253..75dd978 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -24,9 +24,11 @@ else()
     FetchContent_MakeAvailable(gtest)
 endif()
 
+find_package(curl REQUIRED)
+
 add_executable(httplib-test test.cc)
 target_compile_options(httplib-test PRIVATE "$<$<CXX_COMPILER_ID:MSVC>:/utf-8;/bigobj>")
-target_link_libraries(httplib-test PRIVATE httplib GTest::gtest_main)
+target_link_libraries(httplib-test PRIVATE httplib GTest::gtest_main CURL::libcurl)
 gtest_discover_tests(httplib-test)
 
 file(
diff --git a/test/test.cc b/test/test.cc
index 09a2eba..c75cdd9 100644
--- a/test/test.cc
+++ b/test/test.cc
@@ -1,6 +1,7 @@
 #include <httplib.h>
 #include <signal.h>
 
+#include <curl/curl.h>
 #include <gtest/gtest.h>
 
 #include <atomic>
@@ -12,6 +13,7 @@
 #include <stdexcept>
 #include <thread>
 #include <type_traits>
+#include <vector>
 
 #define SERVER_CERT_FILE "./cert.pem"
 #define SERVER_CERT2_FILE "./cert2.pem"
@@ -7606,3 +7608,101 @@ TEST(DirtyDataRequestTest, HeadFieldValueContains_CR_LF_NUL) {
   Client cli(HOST, PORT);
   cli.Get("/test", {{"Test", "_\n\r_\n\r_"}});
 }
+
+TEST(Expect100ContinueTest, ServerClosesConnection) {
+  static constexpr char reject[] = "Unauthorized";
+  static constexpr char accept[] = "Upload accepted";
+  constexpr size_t total_size = 10 * 1024 * 1024 * 1024ULL;
+
+  Server svr;
+
+  svr.set_expect_100_continue_handler([](const Request &req, Response &res) {
+    res.status = StatusCode::Unauthorized_401;
+    res.set_content(reject, "text/plain");
+    return res.status;
+  });
+  svr.Post("/", [&](const Request & /*req*/, Response &res) {
+    res.set_content(accept, "text/plain");
+  });
+
+  auto thread = std::thread([&]() { svr.listen(HOST, PORT); });
+  auto se = detail::scope_exit([&] {
+    svr.stop();
+    thread.join();
+    ASSERT_FALSE(svr.is_running());
+  });
+
+  svr.wait_until_ready();
+
+  {
+    const auto curl = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>{
+        curl_easy_init(), &curl_easy_cleanup};
+    ASSERT_NE(curl, nullptr);
+
+    curl_easy_setopt(curl.get(), CURLOPT_URL, HOST);
+    curl_easy_setopt(curl.get(), CURLOPT_PORT, PORT);
+    curl_easy_setopt(curl.get(), CURLOPT_POST, 1L);
+    auto list = std::unique_ptr<curl_slist, decltype(&curl_slist_free_all)>{
+        curl_slist_append(nullptr, "Content-Type: application/octet-stream"),
+        &curl_slist_free_all};
+    ASSERT_NE(list, nullptr);
+    curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, list.get());
+
+    struct read_data {
+      size_t read_size;
+      size_t total_size;
+    } data = {0, total_size};
+    using read_callback_t =
+        size_t (*)(char *ptr, size_t size, size_t nmemb, void *userdata);
+    read_callback_t read_callback = [](char *ptr, size_t size, size_t nmemb,
+                                       void *userdata) -> size_t {
+      read_data *data = (read_data *)userdata;
+
+      if (!userdata || data->read_size >= data->total_size) { return 0; }
+
+      std::fill_n(ptr, size * nmemb, 'A');
+      data->read_size += size * nmemb;
+      return size * nmemb;
+    };
+    curl_easy_setopt(curl.get(), CURLOPT_READDATA, data);
+    curl_easy_setopt(curl.get(), CURLOPT_READFUNCTION, read_callback);
+
+    std::vector<char> buffer;
+    curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &buffer);
+    using write_callback_t =
+        size_t (*)(char *ptr, size_t size, size_t nmemb, void *userdata);
+    write_callback_t write_callback = [](char *ptr, size_t size, size_t nmemb,
+                                         void *userdata) -> size_t {
+      std::vector<char> *buffer = (std::vector<char> *)userdata;
+      buffer->reserve(buffer->size() + size * nmemb + 1);
+      buffer->insert(buffer->end(), (char *)ptr, (char *)ptr + size * nmemb);
+      return size * nmemb;
+    };
+    curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, write_callback);
+
+    {
+      const auto res = curl_easy_perform(curl.get());
+      ASSERT_EQ(res, CURLE_OK);
+    }
+
+    {
+      auto response_code = long{};
+      const auto res =
+          curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &response_code);
+      ASSERT_EQ(res, CURLE_OK);
+      ASSERT_EQ(response_code, StatusCode::Unauthorized_401);
+    }
+
+    {
+      auto dl = curl_off_t{};
+      const auto res = curl_easy_getinfo(curl.get(), CURLINFO_SIZE_DOWNLOAD_T, &dl);
+      ASSERT_EQ(res, CURLE_OK);
+      ASSERT_EQ(dl, sizeof reject - 1);
+    }
+
+    {
+      buffer.push_back('\0');
+      ASSERT_STRCASEEQ(buffer.data(), reject);
+    }
+  }
+}

From 4c2a608a0c85c9a5f92417255e4ce641b3cfb51a Mon Sep 17 00:00:00 2001
From: yhirose <yuji.hirose.bug@gmail.com>
Date: Wed, 4 Sep 2024 09:06:27 -0400
Subject: [PATCH 2/3] Fix GitHub Actions errors

---
 .github/workflows/test.yaml | 4 ++--
 test/Makefile               | 2 +-
 test/test.cc                | 2 +-
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
index 531fd4d..cf2104f 100644
--- a/.github/workflows/test.yaml
+++ b/.github/workflows/test.yaml
@@ -8,8 +8,8 @@ jobs:
     steps:
       - name: checkout
         uses: actions/checkout@v4
-      - name: install brotli
-        run: sudo apt-get update && sudo apt-get install -y libbrotli-dev
+      - name: install libraries
+        run: sudo apt-get update && sudo apt-get install -y libbrotli-dev libcurl4-openssl-dev
       - name: build and run tests
         run: cd test && make -j4
       - name: run fuzz test target
diff --git a/test/Makefile b/test/Makefile
index 5468488..96ebec9 100644
--- a/test/Makefile
+++ b/test/Makefile
@@ -18,7 +18,7 @@ ZLIB_SUPPORT = -DCPPHTTPLIB_ZLIB_SUPPORT -lz
 BROTLI_DIR = $(PREFIX)/opt/brotli
 BROTLI_SUPPORT = -DCPPHTTPLIB_BROTLI_SUPPORT -I$(BROTLI_DIR)/include -L$(BROTLI_DIR)/lib -lbrotlicommon -lbrotlienc -lbrotlidec
 
-TEST_ARGS = gtest/gtest-all.cc gtest/gtest_main.cc $(OPENSSL_SUPPORT) $(ZLIB_SUPPORT) $(BROTLI_SUPPORT) -pthread
+TEST_ARGS = gtest/gtest-all.cc gtest/gtest_main.cc $(OPENSSL_SUPPORT) $(ZLIB_SUPPORT) $(BROTLI_SUPPORT) -pthread -lcurl
 
 # By default, use standalone_fuzz_target_runner.
 # This runner does no fuzzing, but simply executes the inputs
diff --git a/test/test.cc b/test/test.cc
index 95b993e..231b290 100644
--- a/test/test.cc
+++ b/test/test.cc
@@ -7614,7 +7614,7 @@ TEST(Expect100ContinueTest, ServerClosesConnection) {
 
   Server svr;
 
-  svr.set_expect_100_continue_handler([](const Request &req, Response &res) {
+  svr.set_expect_100_continue_handler([](const Request &/*req*/, Response &res) {
     res.status = StatusCode::Unauthorized_401;
     res.set_content(reject, "text/plain");
     return res.status;

From bd1da4346abd74673a48ac6942bf8ab5de8c4603 Mon Sep 17 00:00:00 2001
From: yhirose <yuji.hirose.bug@gmail.com>
Date: Wed, 4 Sep 2024 09:30:14 -0400
Subject: [PATCH 3/3] Disable Expect100ContinueTest test on Windows

---
 test/test.cc | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/test/test.cc b/test/test.cc
index 231b290..e2ae142 100644
--- a/test/test.cc
+++ b/test/test.cc
@@ -1,7 +1,9 @@
 #include <httplib.h>
 #include <signal.h>
 
+#ifndef _WIN32
 #include <curl/curl.h>
+#endif
 #include <gtest/gtest.h>
 
 #include <atomic>
@@ -7607,6 +7609,7 @@ TEST(DirtyDataRequestTest, HeadFieldValueContains_CR_LF_NUL) {
   cli.Get("/test", {{"Test", "_\n\r_\n\r_"}});
 }
 
+#ifndef _WIN32
 TEST(Expect100ContinueTest, ServerClosesConnection) {
   static constexpr char reject[] = "Unauthorized";
   static constexpr char accept[] = "Upload accepted";
@@ -7704,3 +7707,4 @@ TEST(Expect100ContinueTest, ServerClosesConnection) {
     }
   }
 }
+#endif