diff --git a/include/jwt/algorithm.hpp b/include/jwt/algorithm.hpp index 9cdb08d..46721d4 100644 --- a/include/jwt/algorithm.hpp +++ b/include/jwt/algorithm.hpp @@ -241,14 +241,20 @@ public: { std::error_code ec{}; - EVP_PKEY* pkey = load_key(key); + static auto evpkey_deletor = [](EVP_PKEY* ptr) { + if (ptr) EVP_PKEY_free(ptr); + }; + + std::unique_ptr<EVP_PKEY, decltype(evpkey_deletor)> + pkey{load_key(key), evpkey_deletor}; + if (!pkey) { //TODO: set valid error code return {std::string{}, ec}; } //TODO: Use stack string here ? - std::string sign = evp_digest(pkey, data, ec); + std::string sign = evp_digest(pkey.get(), data, ec); if (ec) { //TODO: handle error_code return {std::move(sign), ec}; @@ -257,7 +263,7 @@ public: if (Hasher::type != EVP_PKEY_EC) { return {std::move(sign), ec}; } else { - sign = public_key_ser(pkey, sign, ec); + sign = public_key_ser(pkey.get(), sign, ec); } return {std::move(sign), ec}; @@ -316,7 +322,7 @@ private: return std::string{}; } - uint32_t len = 0; + unsigned long len = 0; if (EVP_DigestSignFinal(mdctx_ptr.get(), nullptr, &len) != 1) { //TODO: set appropriate error_code @@ -327,7 +333,7 @@ private: sign.resize(len); //Get the signature - if (EVP_DigestSignFinal(mdctx_ptr.get(), &sign[0], &len) != 1) { + if (EVP_DigestSignFinal(mdctx_ptr.get(), (unsigned char*)&sign[0], &len) != 1) { //TODO: set appropriate error_code return std::string{}; } @@ -342,10 +348,14 @@ private: // (optionaly) an associated private key std::string new_sign; - auto eckey_deletor = [](EC_KEY* ptr) { + static auto eckey_deletor = [](EC_KEY* ptr) { if (ptr) EC_KEY_free(ptr); }; + static auto ecsig_deletor = [](ECDSA_SIG* ptr) { + if (ptr) ECDSA_SIG_free(ptr); + }; + std::unique_ptr<EC_KEY, decltype(eckey_deletor)> ec_key{EVP_PKEY_get1_EC_KEY(pkey), eckey_deletor}; @@ -356,9 +366,13 @@ private: uint32_t degree = EC_GROUP_get_degree(EC_KEY_get0_group(ec_key.get())); - ECDSA_SIG* ec_sig = d2i_ECDSA_SIG(nullptr, - (const unsigned char**)&sign[0], - sign.length()); + + std::unique_ptr<ECDSA_SIG, decltype(ecsig_deletor)> + ec_sig{d2i_ECDSA_SIG(nullptr, + (const unsigned char**)&sign[0], + sign.length()), + ecsig_deletor}; + if (!ec_sig) { //TODO set a valid error code return std::string{}; @@ -377,7 +391,7 @@ private: #endif - ECDSA_SIG_get0(ec_sig, &ec_sig_r, &ec_sig_s); + ECDSA_SIG_get0(ec_sig.get(), &ec_sig_r, &ec_sig_s); auto r_len = BN_num_bytes(ec_sig_r); auto s_len = BN_num_bytes(ec_sig_s); diff --git a/include/jwt/base64.hpp b/include/jwt/base64.hpp index 761dcea..faaad0f 100644 --- a/include/jwt/base64.hpp +++ b/include/jwt/base64.hpp @@ -127,6 +127,7 @@ std::string base64_decode(const char* in, size_t len) result.resize(128); int i = 0; size_t bytes_rem = len; + size_t bytes_wr = 0; constexpr static const DMap dmap{}; @@ -150,6 +151,7 @@ std::string base64_decode(const char* in, size_t len) i += 3; in += 4; } + bytes_wr = i; switch(bytes_rem) { case 4: @@ -157,6 +159,7 @@ std::string base64_decode(const char* in, size_t len) auto third = dmap.at(in[2]); auto fourth = dmap.at(in[3]); result[i + 2] = (third << 6) | fourth; + bytes_wr++; //FALLTHROUGH } case 3: @@ -164,6 +167,7 @@ std::string base64_decode(const char* in, size_t len) auto second = dmap.at(in[1]); auto third = dmap.at(in[2]); result[i + 1] = (second << 4) | (third >> 2); + bytes_wr++; //FALLTHROUGH } case 2: @@ -171,9 +175,12 @@ std::string base64_decode(const char* in, size_t len) auto first = dmap.at(in[0]); auto second = dmap.at(in[1]); result[i] = (first << 2) | (second >> 4); + bytes_wr++; } }; + result.resize(bytes_wr); + return result; } diff --git a/include/jwt/jwt.hpp b/include/jwt/jwt.hpp index 9ec6efc..5df200b 100644 --- a/include/jwt/jwt.hpp +++ b/include/jwt/jwt.hpp @@ -2,6 +2,8 @@ #define JWT_HPP #include <cassert> +#include <cstring> +#include <set> #include <string> #include <ostream> @@ -156,12 +158,16 @@ struct write_interface template <typename Derived> struct base64_enc_dec { + /*! + */ std::string base64_encode(bool with_pretty = false) const { std::string jstr = to_json_str(*static_cast<const Derived*>(this), with_pretty); return jwt::base64_encode(jstr.c_str(), jstr.length()); } + /*! + */ static std::string base64_decode(const std::string& encoded_str) { return jwt::base64_decode(encoded_str.c_str(), encoded_str.length()); @@ -233,6 +239,7 @@ public: // Exposed APIs json_t create_json_obj() const { json_t obj = json_t::object(); + //TODO: should be able to do with string_view obj["typ"] = type_to_str(typ_).to_string(); obj["alg"] = alg_to_str(alg_).to_string(); @@ -247,10 +254,88 @@ private: // Data members enum type typ_ = type::JWT; }; + /*! + * JWT Payload */ -struct jwt_payload +struct jwt_payload: write_interface + , base64_enc_dec<jwt_payload> { +public: // 'tors + /*! + */ + jwt_payload() = default; + + /// Default copy and assignment operations + jwt_payload(const jwt_payload&) = default; + jwt_payload& operator=(const jwt_payload&) = default; + + ~jwt_payload() = default; + +public: // Exposed APIs + /*! + */ + template <typename T> + bool add_claim(const std::string& cname, T&& cvalue, bool overwrite=false) + { + // Duplicate claim names not allowed + // if overwrite flag is set to true. + auto itr = claim_names_.find(cname); + if (itr != claim_names_.end() && !overwrite) { + return false; + } + + // Add it to the known set of claims + claim_names_.emplace(cname.data(), cname.length()); + + //Add it to the json payload + //TODO: claim name copied twice inside json + //and in the set + payload_[cname.data()] = std::forward<T>(cvalue); + + return true; + } + + /*! + */ + bool has_claim(const std::string& cname) const noexcept + { + return claim_names_.count(cname); + } + + /*! + */ + template <typename T> + bool has_claim_with_value(const std::string& cname, T&& cvalue) const + { + auto itr = claim_names_.find(cname); + if (itr == claim_names_.end()) return false; + + return (cvalue == payload_[cname]); + } + + /*! + */ + const json_t& create_json_obj() const + { + return payload_; + } + +private: + /*! + */ + struct case_compare { + bool operator()(const std::string& lhs, const std::string& rhs) const + { + int ret = strcasecmp(lhs.c_str(), rhs.c_str()); + return (ret < 0); + } + }; + + /// JSON object containing payload + json_t payload_; + /// The set of claim names in the payload + std::set<std::string, case_compare> claim_names_; }; /*! diff --git a/include/jwt/test/compile.txt b/include/jwt/test/compile.txt new file mode 100644 index 0000000..b2cc8eb --- /dev/null +++ b/include/jwt/test/compile.txt @@ -0,0 +1 @@ +g++ -std=c++14 -I /usr/local/Cellar/openssl/1.0.2j/include/ -I /Users/amuralid/dev_test/cpp-jwt/include/ -o test_rsa test_rsa.cc -L /usr/local/Cellar//openssl/1.0.2j/lib/ -lssl -lcrypto diff --git a/include/jwt/test/test_jwt_payload b/include/jwt/test/test_jwt_payload new file mode 100755 index 0000000..2e39d24 Binary files /dev/null and b/include/jwt/test/test_jwt_payload differ diff --git a/include/jwt/test/test_jwt_payload.cc b/include/jwt/test/test_jwt_payload.cc new file mode 100644 index 0000000..809b669 --- /dev/null +++ b/include/jwt/test/test_jwt_payload.cc @@ -0,0 +1,34 @@ +#include <iostream> +#include "jwt/jwt.hpp" + +void basic_payload_test() +{ + jwt::jwt_payload jp; + jp.add_claim("iss", "myself"); + jp.add_claim("exp", 1234567); + jp.add_claim("Exp", 1234567, true); + + auto jstr = jwt::to_json_str(jp); + std::cout << jstr << std::endl; + + auto enc = jp.base64_encode(); + std::cout << "Base64 enc: " << enc << std::endl; + + auto dec = jp.base64_decode(enc); + std::cout << "Base64 dec: " << dec << std::endl; + std::cout << "Base64 dec: " << jstr << std::endl; + + assert (jstr == dec && "Encoded and decoded messages do not match"); + assert (jp.has_claim("exp") && "Claim exp must exist"); + assert (jp.has_claim("Exp") && "Claim Exp must exist"); + + assert (!jp.has_claim("aud") && "Claim aud does not exist"); + assert (jp.has_claim_with_value("exp", 1234567) && "Claim exp with value 1234567 does not exist"); + + return; +} + +int main() { + basic_payload_test(); + return 0; +} diff --git a/include/jwt/test/test_rsa b/include/jwt/test/test_rsa new file mode 100755 index 0000000..c9adf83 Binary files /dev/null and b/include/jwt/test/test_rsa differ diff --git a/include/jwt/test/test_rsa.cc b/include/jwt/test/test_rsa.cc new file mode 100644 index 0000000..a40624a --- /dev/null +++ b/include/jwt/test/test_rsa.cc @@ -0,0 +1,47 @@ +#include <iostream> +#include "jwt/algorithm.hpp" + +static const char* rsa_2048_pem = +R"(-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDC2kwAziXUf33m +iqWp0yG6o259+nj7hpQLC4UT0Hmz0wmvreDJ/yNbSgOvsxvVdvzL2IaRZ+Gi5mo0 +lswWvL6IGz7PZO0kXTq9sdBnNqMOx27HddV9e/2/p0MgibJTbgywY2Sk23QYhJpq +Kq/nU0xlBfSaI5ddZ2RC9ZNkVeGawUKYksTruhAVJqviHN8BoK6VowP5vcxyyOWH +TK9KruDqzCIhqwRTeo0spokBkTN/LCuhVivcHAzUiJVtB4qAiTI9L/zkzhjpKz9P +45aLU54rj011gG8U/6E1USh5nMnPkr+d3oLfkhfS3Zs3kJVdyFQWZpQxiTaI92Fd +2wLvbS0HAgMBAAECggEAD8dTnkETSSjlzhRuI9loAtAXM3Zj86JLPLW7GgaoxEoT +n7lJ2bGicFMHB2ROnbOb9vnas82gtOtJsGaBslmoaCckp/C5T1eJWTEb+i+vdpPp +wZcmKZovyyRFSE4+NYlU17fEv6DRvuaGBpDcW7QgHJIl45F8QWEM+msee2KE+V4G +z/9vAQ+sOlvsb4mJP1tJIBx9Lb5loVREwCRy2Ha9tnWdDNar8EYkOn8si4snPT+E +3ZCy8mlcZyUkZeiS/HdtydxZfoiwrSRYamd1diQpPhWCeRteQ802a7ds0Y2YzgfF +UaYjNuRQm7zA//hwbXS7ELPyNMU15N00bajlG0tUOQKBgQDnLy01l20OneW6A2cI +DIDyYhy5O7uulsaEtJReUlcjEDMkin8b767q2VZHb//3ZH+ipnRYByUUyYUhdOs2 +DYRGGeAebnH8wpTT4FCYxUsIUpDfB7RwfdBONgaKewTJz/FPswy1Ye0b5H2c6vVi +m2FZ33HQcoZ3wvFFqyGVnMzpOwKBgQDXxL95yoxUGKa8vMzcE3Cn01szh0dFq0sq +cFpM+HWLVr84CItuG9H6L0KaStEEIOiJsxOVpcXfFFhsJvOGhMA4DQTwH4WuXmXp +1PoVMDlV65PYqvhzwL4+QhvZO2bsrEunITXOmU7CI6kilnAN3LuP4HbqZgoX9lqP +I31VYzLupQKBgGEYck9w0s/xxxtR9ILv5XRnepLdoJzaHHR991aKFKjYU/KD7JDK +INfoAhGs23+HCQhCCtkx3wQVA0Ii/erM0II0ueluD5fODX3TV2ZibnoHW2sgrEsW +vFcs36BnvIIaQMptc+f2QgSV+Z/fGsKYadG6Q+39O7au/HB7SHayzWkjAoGBAMgt +Fzslp9TpXd9iBWjzfCOnGUiP65Z+GWkQ/SXFqD+SRir0+m43zzGdoNvGJ23+Hd6K +TdQbDJ0uoe4MoQeepzoZEgi4JeykVUZ/uVfo+nh06yArVf8FxTm7WVzLGGzgV/uA ++wtl/cRtEyAsk1649yW/KHPEIP8kJdYAJeoO8xSlAoGAERMrkFR7KGYZG1eFNRdV +mJMq+Ibxyw8ks/CbiI+n3yUyk1U8962ol2Q0T4qjBmb26L5rrhNQhneM4e8mo9FX +LlQapYkPvkdrqW0Bp72A/UNAvcGTmN7z5OCJGMUutx2hmEAlrYmpLKS8pM/p9zpK +tEOtzsP5GMDYVlEp1jYSjzQ= +-----END PRIVATE KEY-----)"; + +void basic_rsa_test() +{ + jwt::string_view sv = rsa_2048_pem; + jwt::string_view d = "Some random data string"; + + auto res = jwt::PEMSign<jwt::algo::RS256>::sign(sv, d); + + std::cout << res.first << std::endl; +} + +int main() { + basic_rsa_test(); + return 0; +}