From 744e8e7071f969be775fa7708ee05891289c48bc Mon Sep 17 00:00:00 2001
From: yhirose <yuji.hirose.bug@gmail.com>
Date: Fri, 12 Apr 2019 23:34:27 -0400
Subject: [PATCH] Fix #144

---
 httplib.h    | 56 ++++++++++++++++++++++++++++++++++++++--------------
 test/test.cc | 50 ++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 91 insertions(+), 15 deletions(-)

diff --git a/httplib.h b/httplib.h
index 99ca369..278e2ff 100644
--- a/httplib.h
+++ b/httplib.h
@@ -85,7 +85,9 @@ typedef int socket_t;
  */
 #define CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND 5
 #define CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND 0
+#define CPPHTTPLIB_KEEPALIVE_MAX_COUNT 5
 #define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 8192
+#define CPPHTTPLIB_PAYLOAD_MAX_LENGTH std::numeric_limits<uint64_t>::max()
 
 namespace httplib {
 
@@ -233,6 +235,7 @@ public:
   void set_logger(Logger logger);
 
   void set_keep_alive_max_count(size_t count);
+  void set_payload_max_length(uint64_t length);
 
   int bind_to_any_port(const char *host, int socket_flags = 0);
   bool listen_after_bind();
@@ -247,6 +250,7 @@ protected:
                        bool &connection_close);
 
   size_t keep_alive_max_count_;
+  size_t payload_max_length_;
 
 private:
   typedef std::vector<std::pair<std::regex, Handler>> Handlers;
@@ -762,6 +766,7 @@ inline const char *status_message(int status) {
   case 400: return "Bad Request";
   case 403: return "Forbidden";
   case 404: return "Not Found";
+  case 413: return "Payload Too Large";
   case 414: return "Request-URI Too Long";
   case 415: return "Unsupported Media Type";
   default:
@@ -782,12 +787,12 @@ inline const char *get_header_value(const Headers &headers, const char *key,
 }
 
 inline uint64_t get_header_value_uint64(const Headers &headers, const char *key,
-								int def = 0) {
-	auto it = headers.find(key);
-	if (it != headers.end()) {
-		return std::strtoull(it->second.data(), nullptr, 10);
-	}
-	return def;
+                                        int def = 0) {
+  auto it = headers.find(key);
+  if (it != headers.end()) {
+    return std::strtoull(it->second.data(), nullptr, 10);
+  }
+  return def;
 }
 
 inline bool read_headers(Stream &strm, Headers &headers) {
@@ -881,7 +886,9 @@ inline bool read_content_chunked(Stream &strm, std::string &out) {
 }
 
 template <typename T>
-bool read_content(Stream &strm, T &x, Progress progress = Progress()) {
+bool read_content(Stream &strm, T &x, uint64_t payload_max_length,
+                  bool &exceed_payload_max_length,
+                  Progress progress = Progress()) {
   if (has_header(x.headers, "Content-Length")) {
     auto len = get_header_value_uint64(x.headers, "Content-Length", 0);
     if (len == 0) {
@@ -891,6 +898,15 @@ bool read_content(Stream &strm, T &x, Progress progress = Progress()) {
         return read_content_chunked(strm, x.body);
       }
     }
+
+    if ((len > payload_max_length) ||
+        // For 32-bit platform
+        (sizeof(size_t) < sizeof(uint64_t) &&
+         len > std::numeric_limits<size_t>::max())) {
+      exceed_payload_max_length = true;
+      return false;
+    }
+
     return read_content_with_length(strm, x.body, len, progress);
   } else {
     const auto &encoding =
@@ -1427,8 +1443,9 @@ inline const std::string &BufferStream::get_buffer() const { return buffer; }
 
 // HTTP server implementation
 inline Server::Server()
-    : keep_alive_max_count_(5), is_running_(false), svr_sock_(INVALID_SOCKET),
-      running_threads_(0) {
+    : keep_alive_max_count_(CPPHTTPLIB_KEEPALIVE_MAX_COUNT),
+      payload_max_length_(CPPHTTPLIB_PAYLOAD_MAX_LENGTH), is_running_(false),
+      svr_sock_(INVALID_SOCKET), running_threads_(0) {
 #ifndef _WIN32
   signal(SIGPIPE, SIG_IGN);
 #endif
@@ -1484,6 +1501,10 @@ inline void Server::set_keep_alive_max_count(size_t count) {
   keep_alive_max_count_ = count;
 }
 
+inline void Server::set_payload_max_length(uint64_t length) {
+  payload_max_length_ = length;
+}
+
 inline int Server::bind_to_any_port(const char *host, int socket_flags) {
   return bind_internal(host, 0, socket_flags);
 }
@@ -1702,8 +1723,7 @@ inline bool Server::listen_internal() {
         std::lock_guard<std::mutex> guard(running_threads_mutex_);
         running_threads_--;
       }
-    })
-        .detach();
+    }).detach();
   }
 
   // TODO: Use thread pool...
@@ -1789,10 +1809,12 @@ inline bool Server::process_request(Stream &strm, bool last_connection,
 
   // Body
   if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH") {
-    if (!detail::read_content(strm, req)) {
-      res.status = 400;
+    bool exceed_payload_max_length = false;
+    if (!detail::read_content(strm, req, payload_max_length_,
+                              exceed_payload_max_length)) {
+      res.status = exceed_payload_max_length ? 413 : 400;
       write_response(strm, last_connection, req, res);
-      return true;
+      return !exceed_payload_max_length;
     }
 
     const auto &content_type = req.get_header_value("Content-Type");
@@ -1975,7 +1997,11 @@ inline bool Client::process_request(Stream &strm, Request &req, Response &res,
 
   // Body
   if (req.method != "HEAD") {
-    if (!detail::read_content(strm, res, req.progress)) { return false; }
+    bool exceed_payload_max_length = false;
+    if (!detail::read_content(strm, res, std::numeric_limits<uint64_t>::max(),
+                              exceed_payload_max_length, req.progress)) {
+      return false;
+    }
 
     if (res.get_header_value("Content-Encoding") == "gzip") {
 #ifdef CPPHTTPLIB_ZLIB_SUPPORT
diff --git a/test/test.cc b/test/test.cc
index deb6f6d..bb71ec8 100644
--- a/test/test.cc
+++ b/test/test.cc
@@ -1288,6 +1288,56 @@ TEST_F(ServerUpDownTest, QuickStartStop) {
   // --gtest_filter=ServerUpDownTest.QuickStartStop --gtest_repeat=1000
 }
 
+class PayloadMaxLengthTest : public ::testing::Test {
+protected:
+  PayloadMaxLengthTest()
+      : cli_(HOST, PORT)
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
+        ,
+        svr_(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE)
+#endif
+  {
+  }
+
+  virtual void SetUp() {
+    svr_.set_payload_max_length(8);
+
+    svr_.Post("/test", [&](const Request & /*req*/, Response &res) {
+      res.set_content("test", "text/plain");
+    });
+
+    t_ = thread([&]() { ASSERT_TRUE(svr_.listen(HOST, PORT)); });
+
+    while (!svr_.is_running()) {
+      msleep(1);
+    }
+  }
+
+  virtual void TearDown() {
+    svr_.stop();
+    t_.join();
+  }
+
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
+  SSLClient cli_;
+  SSLServer svr_;
+#else
+  Client cli_;
+  Server svr_;
+#endif
+  thread t_;
+};
+
+TEST_F(PayloadMaxLengthTest, ExceedLimit) {
+  auto res = cli_.Post("/test", "123456789", "text/plain");
+  ASSERT_TRUE(res != nullptr);
+  EXPECT_EQ(413, res->status);
+
+  res = cli_.Post("/test", "12345678", "text/plain");
+  ASSERT_TRUE(res != nullptr);
+  EXPECT_EQ(200, res->status);
+}
+
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
 TEST(SSLClientTest, ServerNameIndication) {
   SSLClient cli("httpbin.org", 443);