diff --git a/include/jwt/impl/jwt.ipp b/include/jwt/impl/jwt.ipp index 41257bb..5039229 100644 --- a/include/jwt/impl/jwt.ipp +++ b/include/jwt/impl/jwt.ipp @@ -322,7 +322,10 @@ template void jwt_object::set_parameters( params::detail::headers_param&& header, Rest&&... rargs) { - //TODO: add kid support + for (const auto& elem : header.get()) { + header_.add_header(std::move(elem.first), std::move(elem.second)); + } + set_parameters(std::forward(rargs)...); } diff --git a/include/jwt/jwt.hpp b/include/jwt/jwt.hpp index adb2741..0970705 100644 --- a/include/jwt/jwt.hpp +++ b/include/jwt/jwt.hpp @@ -125,6 +125,48 @@ jwt::string_view reg_claims_to_str(enum registered_claims claim) noexcept assert (0 && "Code not reached"); } +/** + * A helper class that enables reuse of the + * std::set container with custom comparator. + */ +struct jwt_set +{ + /** + * Transparent comparator. + * @note: C++14 only. + */ + struct case_compare + { + using is_transparent = std::true_type; + + bool operator()(const std::string& lhs, const std::string& rhs) const + { + int ret = strcmp(lhs.c_str(), rhs.c_str()); + return (ret < 0); + } + + bool operator()(const jwt::string_view lhs, const jwt::string_view rhs) const + { + int ret = strcmp(lhs.data(), rhs.data()); + return (ret < 0); + } + + bool operator()(const std::string& lhs, const jwt::string_view rhs) const + { + int ret = strcmp(lhs.data(), rhs.data()); + return (ret < 0); + } + + bool operator()(const jwt::string_view lhs, const std::string& rhs) const + { + int ret = strcmp(lhs.data(), rhs.data()); + return (ret < 0); + } + }; + + using header_claim_set_t = std::set; +}; + // Fwd declaration for friend functions to specify the // default arguments // See: https://stackoverflow.com/a/23336823/434233 @@ -235,7 +277,11 @@ public: // 'tors /* * Default constructor. */ - jwt_header() = default; + jwt_header() + { + payload_["alg"] = "none"; + payload_["typ"] = "JWT"; + } /** * Constructor taking specified algorithm type @@ -245,6 +291,8 @@ public: // 'tors : alg_(alg) , typ_(typ) { + payload_["typ"] = type_to_str(typ_).to_string(); + payload_["alg"] = alg_to_str(alg_).to_string(); } /** @@ -269,9 +317,10 @@ public: // Exposed APIs /** * Set the algorithm. */ - void algo(enum algorithm alg) noexcept + void algo(enum algorithm alg) { alg_ = alg; + payload_["alg"] = alg_to_str(alg_).to_string(); } /** @@ -280,6 +329,7 @@ public: // Exposed APIs void algo(const jwt::string_view sv) { alg_ = str_to_alg(sv.data()); + payload_["alg"] = alg_to_str(alg_).to_string(); } /** @@ -295,11 +345,21 @@ public: // Exposed APIs * header would not be valid after modifying the type. */ /** - * Set the JWS type. + * Set the JWT type. */ void typ(enum type typ) noexcept { typ_ = typ; + payload_["typ"] = type_to_str(typ_).to_string(); + } + + /** + * Set the JWT type header. String overload. + */ + void typ(const jwt::string_view sv) + { + typ_ = str_to_type(sv.data()); + payload_["typ"] = type_to_str(typ_).to_string(); } /** @@ -310,6 +370,38 @@ public: // Exposed APIs return typ_; } + /** + * Add a header to the JWT header. + */ + template >::value + > + > + bool add_header(const jwt::string_view hname, T&& hvalue, bool overwrite=false) + { + auto itr = headers_.find(hname); + if (itr != std::end(headers_) && !overwrite) { + return false; + } + + headers_.emplace(hname.data(), hname.length()); + payload_[hname.data()] = std::forward(hvalue); + + return true; + } + + /** + * Add a header to the JWT header. + * Overload which takes the header value as `jwt::string_view` + */ + bool add_header(const jwt::string_view cname, const jwt::string_view cvalue, bool overwrite=false) + { + return add_header(cname, + std::string{cvalue.data(), cvalue.length()}, + overwrite); + } + /** * Get the URL safe base64 encoded string * of the header. @@ -347,14 +439,9 @@ public: // Exposed APIs * @note: Presence of this member function is a requirement * for some interfaces (Eg: `write_interface`). */ - json_t create_json_obj() const + const json_t& create_json_obj() const { - json_t obj = json_t::object(); - //TODO: should be able to do with string_view - obj["typ"] = type_to_str(typ_).to_string(); - obj["alg"] = alg_to_str(alg_).to_string(); - - return obj; + return payload_; } private: // Data members @@ -363,6 +450,12 @@ private: // Data members /// The type of header enum type typ_ = type::JWT; + + // The JSON payload object + json_t payload_; + + //Extra headers for JWS + jwt_set::header_claim_set_t headers_; }; @@ -424,10 +517,6 @@ public: // Exposed APIs return false; } - if (itr != claim_names_.end() && overwrite) { - claim_names_.erase(itr); - } - // Add it to the known set of claims claim_names_.emplace(cname.data(), cname.length()); @@ -650,43 +739,11 @@ public: // Exposed APIs } private: - /** - * Transparent comparator. - * @note: C++14 only. - */ - struct case_compare - { - using is_transparent = std::true_type; - - bool operator()(const std::string& lhs, const std::string& rhs) const - { - int ret = strcmp(lhs.c_str(), rhs.c_str()); - return (ret < 0); - } - - bool operator()(const jwt::string_view lhs, const jwt::string_view rhs) const - { - int ret = strcmp(lhs.data(), rhs.data()); - return (ret < 0); - } - - bool operator()(const std::string& lhs, const jwt::string_view rhs) const - { - int ret = strcmp(lhs.data(), rhs.data()); - return (ret < 0); - } - - bool operator()(const jwt::string_view lhs, const std::string& rhs) const - { - int ret = strcmp(lhs.data(), rhs.data()); - return (ret < 0); - } - }; /// JSON object containing payload json_t payload_; /// The set of claim names in the payload - std::set claim_names_; + jwt_set::header_claim_set_t claim_names_; }; /** diff --git a/tests/test_jwt_encode b/tests/test_jwt_encode index 25084e8..2dea4df 100755 Binary files a/tests/test_jwt_encode and b/tests/test_jwt_encode differ diff --git a/tests/test_jwt_encode.cc b/tests/test_jwt_encode.cc index 41f4fe9..b876e48 100644 --- a/tests/test_jwt_encode.cc +++ b/tests/test_jwt_encode.cc @@ -259,6 +259,31 @@ TEST (EncodeTest, OverwriteClaimsTest) EXPECT_TRUE (obj.payload().has_claim_with_value("x-pld1", "1data")); } +TEST (EncodeTest, HeaderParamTest) +{ + using namespace jwt::params; + + jwt::jwt_object obj{ + headers({ + {"alg", "none"}, + {"typ", "jwt"}, + }), + payload({ + {"iss", "arun.muralidharan"}, + {"sub", "nsfw"}, + {"x-pld", "not my ex"} + }) + }; + + bool ret = obj.header().add_header("kid", 1234567); + EXPECT_TRUE (ret); + + ret = obj.header().add_header("crit", std::array{"exp"}); + EXPECT_TRUE (ret); + + std::cout << obj.header() << std::endl; +} + int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv);