From be2a1fdb9d74cec8ee7486ad59f6c709c72de2e4 Mon Sep 17 00:00:00 2001
From: yhirose <yuji.hirose.bug@gmail.com>
Date: Tue, 25 Sep 2012 22:09:56 -0400
Subject: [PATCH] Refactoring and DSL support.

---
 example/sample.cc |  75 +++++++++++++++++---------
 httpsvrkit.h      | 135 ++++++++++++++++++++++++++++++++++------------
 2 files changed, 149 insertions(+), 61 deletions(-)

diff --git a/example/sample.cc b/example/sample.cc
index ddf65d1..039ccb2 100644
--- a/example/sample.cc
+++ b/example/sample.cc
@@ -10,53 +10,76 @@
 
 using namespace httpsvrkit;
 
-int dump_request(Context& cxt)
+std::string dump_request(Context& cxt)
 {
-    auto& body = cxt.response.body;
+    std::string s;
     char buf[BUFSIZ];
 
-    body += "================================\n";
+    s += "================================\n";
 
-    sprintf(buf, "Method: %s, URL: %s\n",
-        cxt.request.method.c_str(),
-        cxt.request.url.c_str());
+    sprintf(buf, "Method: %s\n", cxt.request.method.c_str());
+    s += buf;
 
-    body += buf;
+    sprintf(buf, "URL: %s\n", cxt.request.url.c_str());
+    s += buf;
+
+    std::string query;
+    for (auto it = cxt.request.query.begin(); it != cxt.request.query.end(); ++it) {
+       const auto& x = *it;
+       sprintf(buf, "(%s:%s)", x.first.c_str(), x.second.c_str());
+       query += buf;
+    }
+    sprintf(buf, "QUERY: %s\n", query.c_str());
+    s += buf;
 
     //for (const auto& x : cxt.request.headers) {
     for (auto it = cxt.request.headers.begin(); it != cxt.request.headers.end(); ++it) {
        const auto& x = *it;
        sprintf(buf, "%s: %s\n", x.first.c_str(), x.second.c_str());
-       body += buf;
+       s += buf;
     }
 
-    body += "================================\n";
+    s += "================================\n";
 
-    return 200;
+    return s;
 }
 
 int main(void)
 {
-    Server svr;
+    if (true) {
+        // DSL style
+        HTTP_SERVER("localhost", 1234) {
 
-    svr.get("/", [](Context& cxt) -> int {
-        dump_request(cxt);
-        return 200;
-    });
+            GET("/", {
+                res.set_redirect("/home");
+            });
 
-    svr.post("/item", [](Context& cxt) -> int {
-        dump_request(cxt);
-        cxt.response.body += cxt.request.url;
-        return 200;
-    });
+            GET("/home", {
+                res.set_content(dump_request(cxt));
+            });
+        }
+    } else {
+        // Regular style
+        Server svr("localhost", 1234);
 
-    svr.get("/item/([^/]+)", [](Context& cxt) -> int {
-        dump_request(cxt);
-        cxt.response.body += cxt.request.params[0];
-        return 200;
-    });
+        svr.get("/", [](Context& cxt) {
+            cxt.response.set_redirect("/home");
+        });
 
-    svr.run("localhost", 1234);
+        svr.get("/home", [](Context& cxt) {
+            cxt.response.set_content(dump_request(cxt));
+        });
+
+        svr.post("/item", [](Context& cxt) {
+            cxt.response.set_content(dump_request(cxt));
+        });
+
+        svr.get("/item/([^/]+)", [](Context& cxt) {
+            cxt.response.set_content(dump_request(cxt));
+        });
+
+        svr.run();
+    }
 }
 
 // vim: et ts=4 sw=4 cin cino={1s ff=unix
diff --git a/httpsvrkit.h b/httpsvrkit.h
index e66c808..9e53a6d 100644
--- a/httpsvrkit.h
+++ b/httpsvrkit.h
@@ -5,10 +5,20 @@
 //  The Boost Software License 1.0
 //
 
+#ifndef HTTPSVRKIT_H
+#define HTTPSVRKIT_H
+
 #ifdef _WIN32
 //#define _CRT_SECURE_NO_WARNINGS
 #define _CRT_NONSTDC_NO_DEPRECATE
 
+#ifndef SO_SYNCHRONOUS_NONALERT
+#define SO_SYNCHRONOUS_NONALERT 0x20;
+#endif
+#ifndef SO_OPENTYPE
+#define SO_OPENTYPE 0x7008
+#endif
+
 #include <fcntl.h>
 #include <io.h>
 #include <winsock2.h>
@@ -50,8 +60,12 @@ struct Request {
 
 // HTTP response
 struct Response {
+    int         status;
     MultiMap    headers;
     std::string body;
+
+    void set_redirect(const char* url);
+    void set_content(const std::string& s, const char* content_type = "text/plain");
 };
 
 struct Context {
@@ -62,21 +76,23 @@ struct Context {
 // HTTP server
 class Server {
 public:
-    typedef std::function<int (Context& context)> Handler;
+    typedef std::function<void (Context& context)> Handler;
 
-    Server();
+    Server(const char* ipaddr_or_hostname, int port);
     ~Server();
 
     void get(const char* pattern, Handler handler);
     void post(const char* pattern, Handler handler);
 
-    bool run(const char* ipaddr_or_hostname, int port);
+    bool run();
     void stop();
 
 private:
     void process_request(FILE* fp_read, FILE* fp_write);
 
-    socket_t sock_;
+    const std::string ipaddr_or_hostname_;
+    const int         port_;
+    socket_t          sock_;
     std::vector<std::pair<std::regex, Handler>> get_handlers_;
     std::vector<std::pair<std::string, Handler>> post_handlers_;
 };
@@ -102,8 +118,13 @@ void split(const char* b, const char* e, char d, Fn fn)
     }
 }
 
-inline socket_t create_server_socket(const const char* ipaddr_or_hostname, int port)
+inline socket_t create_server_socket(const char* ipaddr_or_hostname, int port)
 {
+#ifdef _WIN32
+    int opt = SO_SYNCHRONOUS_NONALERT;
+    setsockopt(INVALID_SOCKET, SOL_SOCKET, SO_OPENTYPE, (char*)&opt, sizeof(opt));
+#endif
+
     // Create a server socket
     socket_t sock = socket(AF_INET, SOCK_STREAM, 0);
     if (sock == -1) {
@@ -111,8 +132,8 @@ inline socket_t create_server_socket(const const char* ipaddr_or_hostname, int p
     }
 
     // Make 'reuse address' option available
-    int opt = 1;
-    setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, (char*)&opt, sizeof(opt));
+    int yes = 1;
+    setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, (char*)&yes, sizeof(yes));
 
     // Get a host entry info
     struct hostent* hp;
@@ -128,7 +149,6 @@ inline socket_t create_server_socket(const const char* ipaddr_or_hostname, int p
     addr.sin_port = htons(port);
 
     if (::bind(sock, (struct sockaddr*)&addr, sizeof(addr)) != 0) {
-        puts("(error)\n");
         return -1;
     }
 
@@ -150,21 +170,27 @@ inline void close_socket(socket_t sock)
 #endif
 }
 
-inline Server::Server()
-    : sock_(-1)
+void Response::set_redirect(const char* url)
+{
+    headers.insert(std::make_pair("Location", url));
+    status = 302;
+}
+
+void Response::set_content(const std::string& s, const char* content_type)
+{
+    body = s;
+    headers.insert(std::make_pair("Content-Type", content_type));
+    status = 200;
+}
+
+inline Server::Server(const char* ipaddr_or_hostname, int port)
+    : ipaddr_or_hostname_(ipaddr_or_hostname)
+    , port_(port)
+    , sock_(-1)
 {
 #ifdef _WIN32
     WSADATA wsaData;
     WSAStartup(0x0002, &wsaData);
-
-#ifndef SO_SYNCHRONOUS_NONALERT
-#define SO_SYNCHRONOUS_NONALERT 0x20;
-#endif
-#ifndef SO_OPENTYPE
-#define SO_OPENTYPE 0x7008
-#endif
-    int opt = SO_SYNCHRONOUS_NONALERT;
-    setsockopt(INVALID_SOCKET, SOL_SOCKET, SO_OPENTYPE, (char*)&opt, sizeof(opt));
 #endif
 }
 
@@ -185,9 +211,9 @@ inline void Server::post(const char* pattern, Handler handler)
     post_handlers_.push_back(std::make_pair(pattern, handler));
 }
 
-inline bool Server::run(const const char*ipaddr_or_hostname, int port)
+inline bool Server::run()
 {
-    sock_ = create_server_socket(ipaddr_or_hostname, port);
+    sock_ = create_server_socket(ipaddr_or_hostname_.c_str(), port_);
     if (sock_ == -1) {
         return false;
     }
@@ -206,11 +232,11 @@ inline bool Server::run(const const char*ipaddr_or_hostname, int port)
 
 #ifdef _WIN32
         int osfhandle = _open_osfhandle(fd, _O_RDONLY);
-        FILE* fp_read = fdopen(osfhandle, "r");
-        FILE* fp_write = fdopen(osfhandle, "w");
+        FILE* fp_read = fdopen(osfhandle, "rb");
+        FILE* fp_write = fdopen(osfhandle, "wb");
 #else
-        FILE* fp_read = fdopen(fd, "r");
-        FILE* fp_write = fdopen(fd, "w");
+        FILE* fp_read = fdopen(fd, "rb");
+        FILE* fp_write = fdopen(fd, "wb");
 #endif
 
         process_request(fp_read, fp_write);
@@ -282,13 +308,37 @@ inline void read_headers(FILE* fp, Map& headers)
     }
 }
 
-inline void write_plain_text(FILE* fp, const char* s)
+inline const char* get_header_value(const MultiMap& map, const char* key, const char* def)
 {
-    fprintf(fp, "HTTP/1.0 200 OK\r\n");
-    fprintf(fp, "Content-type: text/plain\r\n");
+    auto it = map.find(key);
+    if (it != map.end()) {
+        return it->second.c_str();
+    }
+    return def;
+}
+
+inline void write_response(FILE* fp, const Response& response)
+{
+    fprintf(fp, "HTTP/1.0 %d OK\r\n", response.status);
     fprintf(fp, "Connection: close\r\n");
+
+    for (auto it = response.headers.begin(); it != response.headers.end(); ++it) {
+        if (it->first != "Content-Type" && it->second != "Content-Length") {
+            fprintf(fp, "%s: %s\r\n", it->first.c_str(), it->second.c_str());
+        }
+    }
+
+    if (!response.body.empty()) {
+        auto content_type = get_header_value(response.headers, "Content-Type", "text/plain");
+        fprintf(fp, "Content-Type: %s\r\n", content_type);
+        fprintf(fp, "Content-Length: %ld\r\n", response.body.size());
+    }
+
     fprintf(fp, "\r\n");
-    fprintf(fp, "%s", s);
+
+    if (!response.body.empty()) {
+        fprintf(fp, "%s", response.body.c_str());
+    }
 }
 
 inline void write_error(FILE* fp, int status)
@@ -331,7 +381,7 @@ inline void Server::process_request(FILE* fp_read, FILE* fp_write)
     read_headers(fp_read, cxt.request.headers);
 
     // Routing
-    int status = 404;
+    cxt.response.status = 404;
 
     if (cxt.request.method == "GET") {
         for (auto it = get_handlers_.begin(); it != get_handlers_.end(); ++it) {
@@ -343,22 +393,37 @@ inline void Server::process_request(FILE* fp_read, FILE* fp_write)
                 for (size_t i = 1; i < m.size(); i++) {
                     cxt.request.params.push_back(m[i]);
                 }
-                status = handler(cxt);
+                handler(cxt);
+                break;
             }
         }
     } else if (cxt.request.method == "POST") {
         // TODO: parse body
     } else {
-        status = 400;
+        cxt.response.status = 400;
     }
 
-    if (status == 200) {
-        write_plain_text(fp_write, cxt.response.body.c_str());
+    if (200 <= cxt.response.status && cxt.response.status < 400) {
+        write_response(fp_write, cxt.response);
     } else {
-        write_error(fp_write, status);
+        write_error(fp_write, cxt.response.status);
     }
 }
 
+#define HTTP_SERVER(host, port) \
+    for (std::shared_ptr<httpsvrkit::Server> svr = std::make_shared<httpsvrkit::Server>(host, port); \
+         svr; \
+         svr->run(), svr.reset())
+
+#define GET(url, body) \
+    svr->get(url, [](httpsvrkit::Context& cxt) { \
+        const auto& req = cxt.request; \
+        auto& res = cxt.response; \
+        body \
+    });
+
 } // namespace httpsvrkit
 
+#endif
+
 // vim: et ts=4 sw=4 cin cino={1s ff=unix