From 033bc357234eb7db21494347d87d261fb5037066 Mon Sep 17 00:00:00 2001
From: yhirose <yuji.hirose.bug@gmail.com>
Date: Mon, 2 Dec 2019 07:11:12 -0500
Subject: [PATCH] Improve multipart content reader interface

---
 README.md    |  8 ++++----
 httplib.h    | 44 +++++++++++++++++++++++---------------------
 test/test.cc |  8 ++++----
 3 files changed, 31 insertions(+), 29 deletions(-)

diff --git a/README.md b/README.md
index b509a10..36c74ce 100644
--- a/README.md
+++ b/README.md
@@ -121,14 +121,14 @@ svr.Post("/content_receiver",
     if (req.is_multipart_form_data()) {
       MultipartFiles files;
       content_reader(
+        [&](const std::string &name, const MultipartFile &file) {
+          files.emplace(name, file);
+          return true;
+        },
         [&](const std::string &name, const char *data, size_t data_length) {
           auto &file = files.find(name)->second;
           file.content.append(data, data_length);
           return true;
-        },
-        [&](const std::string &name, const MultipartFile &file) {
-          files.emplace(name, file);
-          return true;
         });
     } else {
       std::string body;
diff --git a/httplib.h b/httplib.h
index b7d8dbb..914433f 100644
--- a/httplib.h
+++ b/httplib.h
@@ -225,22 +225,22 @@ using MultipartFormDataItems = std::vector<MultipartFormData>;
 using ContentReceiver =
     std::function<bool(const char *data, size_t data_length)>;
 
-using MultipartContentReceiver =
-    std::function<bool(const std::string& name, const char *data, size_t data_length)>;
-
 using MultipartContentHeader =
     std::function<bool(const std::string &name, const MultipartFile &file)>;
 
+using MultipartContentReceiver =
+    std::function<bool(const std::string& name, const char *data, size_t data_length)>;
+
 class ContentReader {
   public:
     using Reader = std::function<bool(ContentReceiver receiver)>;
-    using MultipartReader = std::function<bool(MultipartContentReceiver receiver, MultipartContentHeader header)>;
+    using MultipartReader = std::function<bool(MultipartContentHeader header, MultipartContentReceiver receiver)>;
 
     ContentReader(Reader reader, MultipartReader muitlpart_reader)
       : reader_(reader), muitlpart_reader_(muitlpart_reader) {}
 
-    bool operator()(MultipartContentReceiver receiver, MultipartContentHeader header) const {
-      return muitlpart_reader_(receiver, header);
+    bool operator()(MultipartContentHeader header, MultipartContentReceiver receiver) const {
+      return muitlpart_reader_(header, receiver);
     }
 
     bool operator()(ContentReceiver receiver) const {
@@ -591,13 +591,13 @@ private:
   bool read_content_with_content_receiver(Stream &strm, bool last_connection,
                                           Request &req, Response &res,
                                           ContentReceiver receiver,
-                                          MultipartContentReceiver multipart_receiver,
-                                          MultipartContentHeader multipart_header);
+                                          MultipartContentHeader multipart_header,
+                                          MultipartContentReceiver multipart_receiver);
   bool read_content_core(Stream &strm, bool last_connection,
                          Request &req, Response &res,
                          ContentReceiver receiver,
-                         MultipartContentReceiver multipart_receiver,
-                         MultipartContentHeader mulitpart_header);
+                         MultipartContentHeader mulitpart_header,
+                         MultipartContentReceiver multipart_receiver);
 
   virtual bool process_and_close_socket(socket_t sock);
 
@@ -2796,11 +2796,17 @@ Server::write_content_with_provider(Stream &strm, const Request &req,
 inline bool Server::read_content(Stream &strm, bool last_connection,
                                  Request &req, Response &res) {
   auto ret = read_content_core(strm, last_connection, req, res,
+    // Regular
     [&](const char *buf, size_t n) {
       if (req.body.size() + n > req.body.max_size()) { return false; }
       req.body.append(buf, n);
       return true;
     },
+    // Multipart
+    [&](const std::string &name, const MultipartFile &file) {
+      req.files.emplace(name, file);
+      return true;
+    },
     [&](const std::string &name, const char *buf, size_t n) {
       // TODO: handle elements with a same key
       auto it = req.files.find(name);
@@ -2808,10 +2814,6 @@ inline bool Server::read_content(Stream &strm, bool last_connection,
       if (content.size() + n > content.max_size()) { return false; }
       content.append(buf, n);
       return true;
-    },
-    [&](const std::string &name, const MultipartFile &file) {
-      req.files.emplace(name, file);
-      return true;
     }
   );
 
@@ -2827,18 +2829,18 @@ inline bool
 Server::read_content_with_content_receiver(Stream &strm, bool last_connection,
                                            Request &req, Response &res,
                                            ContentReceiver receiver,
-                                           MultipartContentReceiver multipart_receiver,
-                                           MultipartContentHeader multipart_header) {
+                                           MultipartContentHeader multipart_header,
+                                           MultipartContentReceiver multipart_receiver) {
   return read_content_core(strm, last_connection, req, res,
-      receiver, multipart_receiver, multipart_header);
+      receiver, multipart_header, multipart_receiver);
 }
 
 inline bool
 Server::read_content_core(Stream &strm, bool last_connection,
                           Request &req, Response &res,
                           ContentReceiver receiver,
-                          MultipartContentReceiver multipart_receiver,
-                          MultipartContentHeader mulitpart_header) {
+                          MultipartContentHeader mulitpart_header,
+                          MultipartContentReceiver multipart_receiver) {
   detail::MultipartFormDataParser multipart_form_data_parser;
   ContentReceiver out;
 
@@ -3001,9 +3003,9 @@ inline bool Server::routing(Request &req, Response &res, Stream &strm,
           return read_content_with_content_receiver(strm, last_connection, req, res,
                                                     receiver, nullptr, nullptr);
         },
-        [&](MultipartContentReceiver receiver, MultipartContentHeader header) {
+        [&](MultipartContentHeader header, MultipartContentReceiver receiver) {
           return read_content_with_content_receiver(strm, last_connection, req, res,
-                                                    nullptr, receiver, header);
+                                                    nullptr, header, receiver);
         }
       );
 
diff --git a/test/test.cc b/test/test.cc
index f60235c..cc92796 100644
--- a/test/test.cc
+++ b/test/test.cc
@@ -761,14 +761,14 @@ protected:
                 if (req.is_multipart_form_data()) {
                   MultipartFiles files;
                   content_reader(
+                    [&](const std::string &name, const MultipartFile &file) {
+                      files.emplace(name, file);
+                      return true;
+                    },
                     [&](const std::string &name, const char *data, size_t data_length) {
                       auto &file = files.find(name)->second;
                       file.content.append(data, data_length);
                       return true;
-                    },
-                    [&](const std::string &name, const MultipartFile &file) {
-                      files.emplace(name, file);
-                      return true;
                     });
 
                   EXPECT_EQ(5u, files.size());