From 5324b3d661caaed4c9962e3a0cb4e48bb60877bf Mon Sep 17 00:00:00 2001
From: yhirose <yuji.hirose.bug@gmail.com>
Date: Thu, 12 Dec 2019 22:44:54 -0500
Subject: [PATCH] Improved multipart form data interface

---
 README.md    | 11 +++++------
 httplib.h    | 44 +++++++++++++++++++-------------------------
 test/test.cc | 19 ++++++++++---------
 3 files changed, 34 insertions(+), 40 deletions(-)

diff --git a/README.md b/README.md
index 60e8cf9..ddc8205 100644
--- a/README.md
+++ b/README.md
@@ -119,15 +119,14 @@ svr.Get("/stream", [&](const Request &req, Response &res) {
 svr.Post("/content_receiver",
   [&](const Request &req, Response &res, const ContentReader &content_reader) {
     if (req.is_multipart_form_data()) {
-      MultipartFiles files;
+      MultipartFormDataItems files;
       content_reader(
-        [&](const std::string &name, const MultipartFile &file) {
-          files.emplace(name, file);
+        [&](const MultipartFormData &file) {
+          files.push_back(file);
           return true;
         },
-        [&](const std::string &name, const char *data, size_t data_length) {
-          auto &file = files.find(name)->second;
-          file.content.append(data, data_length);
+        [&](const char *data, size_t data_length) {
+          files.back().content.append(data, data_length);
           return true;
         });
     } else {
diff --git a/httplib.h b/httplib.h
index 7dccf49..b185270 100644
--- a/httplib.h
+++ b/httplib.h
@@ -224,20 +224,17 @@ using ContentReceiver =
     std::function<bool(const char *data, size_t data_length)>;
 
 using MultipartContentHeader =
-    std::function<bool(const std::string &name, const MultipartFormData &file)>;
-
-using MultipartContentReceiver =
-    std::function<bool(const std::string& name, const char *data, size_t data_length)>;
+    std::function<bool(const MultipartFormData &file)>;
 
 class ContentReader {
   public:
     using Reader = std::function<bool(ContentReceiver receiver)>;
-    using MultipartReader = std::function<bool(MultipartContentHeader header, MultipartContentReceiver receiver)>;
+    using MultipartReader = std::function<bool(MultipartContentHeader header, ContentReceiver receiver)>;
 
     ContentReader(Reader reader, MultipartReader muitlpart_reader)
       : reader_(reader), muitlpart_reader_(muitlpart_reader) {}
 
-    bool operator()(MultipartContentHeader header, MultipartContentReceiver receiver) const {
+    bool operator()(MultipartContentHeader header, ContentReceiver receiver) const {
       return muitlpart_reader_(header, receiver);
     }
 
@@ -590,12 +587,12 @@ private:
                                           Request &req, Response &res,
                                           ContentReceiver receiver,
                                           MultipartContentHeader multipart_header,
-                                          MultipartContentReceiver multipart_receiver);
+                                          ContentReceiver multipart_receiver);
   bool read_content_core(Stream &strm, bool last_connection,
                          Request &req, Response &res,
                          ContentReceiver receiver,
                          MultipartContentHeader mulitpart_header,
-                         MultipartContentReceiver multipart_receiver);
+                         ContentReceiver multipart_receiver);
 
   virtual bool process_and_close_socket(socket_t sock);
 
@@ -2011,7 +2008,7 @@ public:
         while (pos != std::string::npos) {
           // Empty line
           if (pos == 0) {
-            if (!header_callback(name_, file_)) {
+            if (!header_callback(file_)) {
               is_valid_ = false;
               is_done_ = false;
               return false;
@@ -2028,8 +2025,7 @@ public:
             if (std::regex_match(header, m, re_content_type)) {
               file_.content_type = m[1];
             } else if (std::regex_match(header, m, re_content_disposition)) {
-              name_ = m[1];
-              file_.name = name_;
+              file_.name = m[1];
               file_.filename = m[2];
             }
           }
@@ -2047,7 +2043,7 @@ public:
           if (pos == std::string::npos) {
             pos = buf_.size();
           }
-          if (!content_callback(name_, buf_.data(), pos)) {
+          if (!content_callback(buf_.data(), pos)) {
             is_valid_ = false;
             is_done_ = false;
             return false;
@@ -2063,7 +2059,7 @@ public:
 
           auto pos = buf_.find(pattern);
           if (pos != std::string::npos) {
-            if (!content_callback(name_, buf_.data(), pos)) {
+            if (!content_callback(buf_.data(), pos)) {
               is_valid_ = false;
               is_done_ = false;
               return false;
@@ -2073,7 +2069,7 @@ public:
             buf_.erase(0, pos + pattern.size());
             state_ = 4;
           } else {
-            if (!content_callback(name_, buf_.data(), pattern.size())) {
+            if (!content_callback(buf_.data(), pattern.size())) {
               is_valid_ = false;
               is_done_ = false;
               return false;
@@ -2118,7 +2114,7 @@ public:
 
 private:
   void clear_file_info() {
-    name_.clear();
+    file_.name.clear();
     file_.filename.clear();
     file_.content_type.clear();
   }
@@ -2132,7 +2128,6 @@ private:
   size_t is_valid_ = false;
   size_t is_done_ = false;
   size_t off_ = 0;
-  std::string name_;
   MultipartFormData file_;
 };
 
@@ -2973,6 +2968,7 @@ Server::write_content_with_provider(Stream &strm, const Request &req,
 
 inline bool Server::read_content(Stream &strm, bool last_connection,
                                  Request &req, Response &res) {
+  MultipartFormDataMap::iterator cur;
   auto ret = read_content_core(strm, last_connection, req, res,
     // Regular
     [&](const char *buf, size_t n) {
@@ -2981,14 +2977,12 @@ inline bool Server::read_content(Stream &strm, bool last_connection,
       return true;
     },
     // Multipart
-    [&](const std::string &name, const MultipartFormData &file) {
-      req.files.emplace(name, file);
+    [&](const MultipartFormData &file) {
+      cur = req.files.emplace(file.name, file);
       return true;
     },
-    [&](const std::string &name, const char *buf, size_t n) {
-      // TODO: handle elements with a same key
-      auto it = req.files.find(name);
-      auto &content = it->second.content;
+    [&](const char *buf, size_t n) {
+      auto &content = cur->second.content;
       if (content.size() + n > content.max_size()) { return false; }
       content.append(buf, n);
       return true;
@@ -3008,7 +3002,7 @@ Server::read_content_with_content_receiver(Stream &strm, bool last_connection,
                                            Request &req, Response &res,
                                            ContentReceiver receiver,
                                            MultipartContentHeader multipart_header,
-                                           MultipartContentReceiver multipart_receiver) {
+                                           ContentReceiver multipart_receiver) {
   return read_content_core(strm, last_connection, req, res,
       receiver, multipart_header, multipart_receiver);
 }
@@ -3018,7 +3012,7 @@ Server::read_content_core(Stream &strm, bool last_connection,
                           Request &req, Response &res,
                           ContentReceiver receiver,
                           MultipartContentHeader mulitpart_header,
-                          MultipartContentReceiver multipart_receiver) {
+                          ContentReceiver multipart_receiver) {
   detail::MultipartFormDataParser multipart_form_data_parser;
   ContentReceiver out;
 
@@ -3181,7 +3175,7 @@ inline bool Server::routing(Request &req, Response &res, Stream &strm,
           return read_content_with_content_receiver(strm, last_connection, req, res,
                                                     receiver, nullptr, nullptr);
         },
-        [&](MultipartContentHeader header, MultipartContentReceiver receiver) {
+        [&](MultipartContentHeader header, ContentReceiver receiver) {
           return read_content_with_content_receiver(strm, last_connection, req, res,
                                                     nullptr, header, receiver);
         }
diff --git a/test/test.cc b/test/test.cc
index a210348..a189d8f 100644
--- a/test/test.cc
+++ b/test/test.cc
@@ -30,9 +30,11 @@ const std::string JSON_DATA = "{\"hello\":\"world\"}";
 
 const string LARGE_DATA = string(1024 * 1024 * 100, '@'); // 100MB
 
-MultipartFormData& get_file_value(MultipartFormDataMap &files, const char *key) {
-  auto it = files.find(key);
-  if (it != files.end()) { return it->second; }
+MultipartFormData& get_file_value(MultipartFormDataItems &files, const char *key) {
+  auto it = std::find_if(files.begin(), files.end(), [&](const MultipartFormData &file) {
+    return file.name == key;
+  });
+  if (it != files.end()) { return *it; }
   throw std::runtime_error("invalid mulitpart form data name error");
 }
 
@@ -801,15 +803,14 @@ protected:
         .Post("/content_receiver",
               [&](const Request & req, Response &res, const ContentReader &content_reader) {
                 if (req.is_multipart_form_data()) {
-                  MultipartFormDataMap files;
+                  MultipartFormDataItems files;
                   content_reader(
-                    [&](const std::string &name, const MultipartFormData &file) {
-                      files.emplace(name, file);
+                    [&](const MultipartFormData &file) {
+                      files.push_back(file);
                       return true;
                     },
-                    [&](const std::string &name, const char *data, size_t data_length) {
-                      auto &file = files.find(name)->second;
-                      file.content.append(data, data_length);
+                    [&](const char *data, size_t data_length) {
+                      files.back().content.append(data, data_length);
                       return true;
                     });