From e5dd410256de59e4ff0da91d87102d8e43f50cdd Mon Sep 17 00:00:00 2001
From: yhirose <yuji.hirose.bug@gmail.com>
Date: Sat, 15 Aug 2020 05:53:49 -0400
Subject: [PATCH] Added set_content_provider without content length

---
 httplib.h    | 111 ++++++++++++++++++++++++++++++++++++++-------------
 test/test.cc |  43 +++++++++++++++++++-
 2 files changed, 124 insertions(+), 30 deletions(-)

diff --git a/httplib.h b/httplib.h
index 2df93c7..bf2f64c 100644
--- a/httplib.h
+++ b/httplib.h
@@ -299,7 +299,7 @@ private:
 using ContentProvider =
     std::function<bool(size_t offset, size_t length, DataSink &sink)>;
 
-using ChunkedContentProvider =
+using ContentProviderWithoutLength =
     std::function<bool(size_t offset, DataSink &sink)>;
 
 using ContentReceiver =
@@ -404,8 +404,12 @@ struct Response {
       size_t length, const char *content_type, ContentProvider provider,
       std::function<void()> resource_releaser = [] {});
 
+  void set_content_provider(
+      const char *content_type, ContentProviderWithoutLength provider,
+      std::function<void()> resource_releaser = [] {});
+
   void set_chunked_content_provider(
-      const char *content_type, ChunkedContentProvider provider,
+      const char *content_type, ContentProviderWithoutLength provider,
       std::function<void()> resource_releaser = [] {});
 
   Response() = default;
@@ -423,6 +427,7 @@ struct Response {
   size_t content_length_ = 0;
   ContentProvider content_provider_;
   std::function<void()> content_provider_resource_releaser_;
+  bool is_chunked_content_provider = false;
 };
 
 class Stream {
@@ -2664,19 +2669,19 @@ inline ssize_t write_content(Stream &strm, ContentProvider content_provider,
                              size_t offset, size_t length, T is_shutting_down) {
   size_t begin_offset = offset;
   size_t end_offset = offset + length;
-
   auto ok = true;
-
   DataSink data_sink;
+
   data_sink.write = [&](const char *d, size_t l) {
     if (ok) {
       offset += l;
       if (!write_data(strm, d, l)) { ok = false; }
     }
   };
+
   data_sink.is_writable = [&](void) { return ok && strm.is_writable(); };
 
-  while (ok && offset < end_offset && !is_shutting_down()) {
+  while (offset < end_offset && !is_shutting_down()) {
     if (!content_provider(offset, end_offset - offset, data_sink)) {
       return -1;
     }
@@ -2686,6 +2691,34 @@ inline ssize_t write_content(Stream &strm, ContentProvider content_provider,
   return static_cast<ssize_t>(offset - begin_offset);
 }
 
+template <typename T>
+inline ssize_t write_content_without_length(Stream &strm,
+                                            ContentProvider content_provider,
+                                            T is_shutting_down) {
+  size_t offset = 0;
+  auto data_available = true;
+  auto ok = true;
+  DataSink data_sink;
+
+  data_sink.write = [&](const char *d, size_t l) {
+    if (ok) {
+      offset += l;
+      if (!write_data(strm, d, l)) { ok = false; }
+    }
+  };
+
+  data_sink.done = [&](void) { data_available = false; };
+
+  data_sink.is_writable = [&](void) { return ok && strm.is_writable(); };
+
+  while (data_available && !is_shutting_down()) {
+    if (!content_provider(offset, 0, data_sink)) { return -1; }
+    if (!ok) { return -1; }
+  }
+
+  return static_cast<ssize_t>(offset);
+}
+
 template <typename T, typename U>
 inline ssize_t write_content_chunked(Stream &strm,
                                      ContentProvider content_provider,
@@ -2693,7 +2726,6 @@ inline ssize_t write_content_chunked(Stream &strm,
   size_t offset = 0;
   auto data_available = true;
   ssize_t total_written_length = 0;
-
   auto ok = true;
   DataSink data_sink;
 
@@ -3544,10 +3576,11 @@ Response::set_content_provider(size_t in_length, const char *content_type,
     return provider(offset, length, sink);
   };
   content_provider_resource_releaser_ = resource_releaser;
+  is_chunked_content_provider = false;
 }
 
-inline void Response::set_chunked_content_provider(
-    const char *content_type, ChunkedContentProvider provider,
+inline void Response::set_content_provider(
+    const char *content_type, ContentProviderWithoutLength provider,
     std::function<void()> resource_releaser) {
   set_header("Content-Type", content_type);
   content_length_ = 0;
@@ -3555,6 +3588,19 @@ inline void Response::set_chunked_content_provider(
     return provider(offset, sink);
   };
   content_provider_resource_releaser_ = resource_releaser;
+  is_chunked_content_provider = false;
+}
+
+inline void Response::set_chunked_content_provider(
+    const char *content_type, ContentProviderWithoutLength provider,
+    std::function<void()> resource_releaser) {
+  set_header("Content-Type", content_type);
+  content_length_ = 0;
+  content_provider_ = [provider](size_t offset, size_t, DataSink &sink) {
+    return provider(offset, sink);
+  };
+  content_provider_resource_releaser_ = resource_releaser;
+  is_chunked_content_provider = true;
 }
 
 // Rstream implementation
@@ -3893,7 +3939,7 @@ inline bool Server::write_response(Stream &strm, bool close_connection,
   }
 
   if (!res.has_header("Content-Type") &&
-      (!res.body.empty() || res.content_length_ > 0)) {
+      (!res.body.empty() || res.content_length_ > 0 || res.content_provider_)) {
     res.set_header("Content-Type", "text/plain");
   }
 
@@ -3939,11 +3985,13 @@ inline bool Server::write_response(Stream &strm, bool close_connection,
       res.set_header("Content-Length", std::to_string(length));
     } else {
       if (res.content_provider_) {
-        res.set_header("Transfer-Encoding", "chunked");
-        if (type == detail::EncodingType::Gzip) {
-          res.set_header("Content-Encoding", "gzip");
-        } else if (type == detail::EncodingType::Brotli) {
-          res.set_header("Content-Encoding", "br");
+        if (res.is_chunked_content_provider) {
+          res.set_header("Transfer-Encoding", "chunked");
+          if (type == detail::EncodingType::Gzip) {
+            res.set_header("Content-Encoding", "gzip");
+          } else if (type == detail::EncodingType::Brotli) {
+            res.set_header("Content-Encoding", "br");
+          }
         }
       } else {
         res.set_header("Content-Length", "0");
@@ -4033,7 +4081,7 @@ Server::write_content_with_provider(Stream &strm, const Request &req,
     return this->svr_sock_ == INVALID_SOCKET;
   };
 
-  if (res.content_length_) {
+  if (res.content_length_ > 0) {
     if (req.ranges.empty()) {
       if (detail::write_content(strm, res.content_provider_, 0,
                                 res.content_length_, is_shutting_down) < 0) {
@@ -4055,25 +4103,32 @@ Server::write_content_with_provider(Stream &strm, const Request &req,
       }
     }
   } else {
-    auto type = detail::encoding_type(req, res);
+    if (res.is_chunked_content_provider) {
+      auto type = detail::encoding_type(req, res);
 
-    std::shared_ptr<detail::compressor> compressor;
-    if (type == detail::EncodingType::Gzip) {
+      std::shared_ptr<detail::compressor> compressor;
+      if (type == detail::EncodingType::Gzip) {
 #ifdef CPPHTTPLIB_ZLIB_SUPPORT
-      compressor = std::make_shared<detail::gzip_compressor>();
+        compressor = std::make_shared<detail::gzip_compressor>();
 #endif
-    } else if (type == detail::EncodingType::Brotli) {
+      } else if (type == detail::EncodingType::Brotli) {
 #ifdef CPPHTTPLIB_BROTLI_SUPPORT
-      compressor = std::make_shared<detail::brotli_compressor>();
+        compressor = std::make_shared<detail::brotli_compressor>();
 #endif
-    } else {
-      compressor = std::make_shared<detail::nocompressor>();
-    }
-    assert(compressor != nullptr);
+      } else {
+        compressor = std::make_shared<detail::nocompressor>();
+      }
+      assert(compressor != nullptr);
 
-    if (detail::write_content_chunked(strm, res.content_provider_,
-                                      is_shutting_down, *compressor) < 0) {
-      return false;
+      if (detail::write_content_chunked(strm, res.content_provider_,
+                                        is_shutting_down, *compressor) < 0) {
+        return false;
+      }
+    } else {
+      if (detail::write_content_without_length(strm, res.content_provider_,
+                                               is_shutting_down) < 0) {
+        return false;
+      }
     }
   }
   return true;
diff --git a/test/test.cc b/test/test.cc
index b603acd..c04bb0c 100644
--- a/test/test.cc
+++ b/test/test.cc
@@ -1895,8 +1895,7 @@ TEST_F(ServerTest, ClientStop) {
       auto res = cli_.Get("/streamed-cancel",
                           [&](const char *, uint64_t) { return true; });
       ASSERT_TRUE(!res);
-      EXPECT_TRUE(res.error() == Error::Canceled ||
-                  res.error() == Error::Read);
+      EXPECT_TRUE(res.error() == Error::Canceled || res.error() == Error::Read);
     }));
   }
 
@@ -2730,6 +2729,46 @@ TEST(ServerStopTest, StopServerWithChunkedTransmission) {
   ASSERT_FALSE(svr.is_running());
 }
 
+TEST(StreamingTest, NoContentLengthStreaming) {
+  Server svr;
+
+  svr.Get("/stream", [](const Request & /*req*/, Response &res) {
+    res.set_content_provider(
+        "text/plain", [](size_t offset, DataSink &sink) {
+          if (offset < 6) {
+            sink.os << (offset < 3 ? "a" : "b");
+          } else {
+            sink.done();
+          }
+          return true;
+        });
+  });
+
+  auto listen_thread = std::thread([&svr]() { svr.listen("localhost", PORT); });
+  while (!svr.is_running()) {
+    std::this_thread::sleep_for(std::chrono::milliseconds(1));
+  }
+
+  Client client(HOST, PORT);
+
+  auto get_thread = std::thread([&client]() {
+    auto res = client.Get("/stream", [](const char *data, size_t len) -> bool {
+      EXPECT_EQ("aaabbb", std::string(data, len));
+      return true;
+    });
+  });
+
+  // Give GET time to get a few messages.
+  std::this_thread::sleep_for(std::chrono::milliseconds(500));
+
+  svr.stop();
+
+  listen_thread.join();
+  get_thread.join();
+
+  ASSERT_FALSE(svr.is_running());
+}
+
 TEST(MountTest, Unmount) {
   Server svr;