From 77a77f6d2df8d648cfd481345c000acf211f05b5 Mon Sep 17 00:00:00 2001 From: yhirose Date: Sun, 23 May 2021 18:17:55 -0400 Subject: [PATCH] Added set_default_headers on Server --- httplib.h | 15 +++++++++ test/test.cc | 86 ++++++++++++++++++++++++++++++++++------------------ 2 files changed, 72 insertions(+), 29 deletions(-) diff --git a/httplib.h b/httplib.h index 838fbbc..a8d273a 100644 --- a/httplib.h +++ b/httplib.h @@ -667,6 +667,8 @@ public: Server &set_tcp_nodelay(bool on); Server &set_socket_options(SocketOptions socket_options); + Server &set_default_headers(Headers headers); + Server &set_keep_alive_max_count(size_t count); Server &set_keep_alive_timeout(time_t sec); @@ -786,6 +788,8 @@ private: int address_family_ = AF_UNSPEC; bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY; SocketOptions socket_options_ = default_socket_options; + + Headers default_headers_; }; enum class Error { @@ -4427,6 +4431,11 @@ inline Server &Server::set_socket_options(SocketOptions socket_options) { return *this; } +inline Server &Server::set_default_headers(Headers headers) { + default_headers_ = std::move(headers); + return *this; +} + inline Server &Server::set_keep_alive_max_count(size_t count) { keep_alive_max_count_ = count; return *this; @@ -5131,6 +5140,12 @@ Server::process_request(Stream &strm, bool close_connection, res.version = "HTTP/1.1"; + for (const auto &header : default_headers_) { + if (res.headers.find(header.first) == res.headers.end()) { + res.headers.insert(header); + } + } + #ifdef _WIN32 // TODO: Increase FD_SETSIZE statically (libzmq), dynamically (MySQL). #else diff --git a/test/test.cc b/test/test.cc index d502176..2492de3 100644 --- a/test/test.cc +++ b/test/test.cc @@ -5,9 +5,9 @@ #include #include #include +#include #include #include -#include #define SERVER_CERT_FILE "./cert.pem" #define SERVER_CERT2_FILE "./cert2.pem" @@ -437,26 +437,6 @@ TEST(ChunkedEncodingTest, WithResponseHandlerAndContentReceiver) { EXPECT_EQ(out, body); } -TEST(DefaultHeadersTest, FromHTTPBin) { - Client cli("httpbin.org"); - cli.set_default_headers({make_range_header({{1, 10}})}); - cli.set_connection_timeout(5); - - { - auto res = cli.Get("/range/32"); - ASSERT_TRUE(res); - EXPECT_EQ("bcdefghijk", res->body); - EXPECT_EQ(206, res->status); - } - - { - auto res = cli.Get("/range/32"); - ASSERT_TRUE(res); - EXPECT_EQ("bcdefghijk", res->body); - EXPECT_EQ(206, res->status); - } -} - TEST(RangeTest, FromHTTPBin) { auto host = "httpbin.org"; @@ -968,7 +948,7 @@ TEST(RedirectFromPageWithContent, Redirect) { TEST(PathUrlEncodeTest, PathUrlEncode) { Server svr; - svr.Get("/foo", [](const Request & req, Response &res) { + svr.Get("/foo", [](const Request &req, Response &res) { auto a = req.params.find("a"); if (a != req.params.end()) { res.set_content((*a).second, "text/plain"); @@ -1420,7 +1400,8 @@ protected: const auto &d = *data; auto out_len = std::min(static_cast(length), DATA_CHUNK_SIZE); - auto ret = sink.write(&d[static_cast(offset)], out_len); + auto ret = + sink.write(&d[static_cast(offset)], out_len); EXPECT_TRUE(ret); return true; }, @@ -3199,12 +3180,11 @@ static bool send_request(time_t read_timeout_sec, const std::string &req, std::string *resp = nullptr) { auto error = Error::Success; - auto client_sock = - detail::create_client_socket(HOST, PORT, AF_UNSPEC, false, nullptr, - /*connection_timeout_sec=*/5, 0, - /*read_timeout_sec=*/5, 0, - /*write_timeout_sec=*/5, 0, - std::string(), error); + auto client_sock = detail::create_client_socket( + HOST, PORT, AF_UNSPEC, false, nullptr, + /*connection_timeout_sec=*/5, 0, + /*read_timeout_sec=*/5, 0, + /*write_timeout_sec=*/5, 0, std::string(), error); if (client_sock == INVALID_SOCKET) { return false; } @@ -3684,6 +3664,54 @@ TEST(GetWithParametersTest, GetWithParameters2) { ASSERT_FALSE(svr.is_running()); } +TEST(ClientDefaultHeadersTest, DefaultHeaders) { + Client cli("httpbin.org"); + cli.set_default_headers({make_range_header({{1, 10}})}); + cli.set_connection_timeout(5); + + { + auto res = cli.Get("/range/32"); + ASSERT_TRUE(res); + EXPECT_EQ("bcdefghijk", res->body); + EXPECT_EQ(206, res->status); + } + + { + auto res = cli.Get("/range/32"); + ASSERT_TRUE(res); + EXPECT_EQ("bcdefghijk", res->body); + EXPECT_EQ(206, res->status); + } +} + +TEST(ServerDefaultHeadersTest, DefaultHeaders) { + Server svr; + svr.set_default_headers({{"Hello", "World"}}); + + svr.Get("/", [&](const Request & /*req*/, Response &res) { + res.set_content("ok", "text/plain"); + }); + + auto listen_thread = std::thread([&svr]() { svr.listen("localhost", PORT); }); + while (!svr.is_running()) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + std::this_thread::sleep_for(std::chrono::seconds(1)); + + Client cli("localhost", PORT); + + auto res = cli.Get("/"); + + ASSERT_TRUE(res); + EXPECT_EQ(200, res->status); + EXPECT_EQ("ok", res->body); + EXPECT_EQ("World", res->get_header_value("Hello")); + + svr.stop(); + listen_thread.join(); + ASSERT_FALSE(svr.is_running()); +} + #ifdef CPPHTTPLIB_OPENSSL_SUPPORT TEST(KeepAliveTest, ReadTimeoutSSL) { SSLServer svr(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE);