diff --git a/include/jwt/error_codes.hpp b/include/jwt/error_codes.hpp index 4dacf82..08d2b7d 100644 --- a/include/jwt/error_codes.hpp +++ b/include/jwt/error_codes.hpp @@ -33,6 +33,17 @@ enum class DecodeErrc DuplClaims, }; +/** + */ +enum class VerificationErrc +{ + InvalidAlgorithm = 1, + TokenExpired, + InvalidIssuer, + InvalidAudience, + ImmatureSignature, +}; + /** */ std::error_code make_error_code(AlgorithmErrc err); @@ -41,6 +52,10 @@ std::error_code make_error_code(AlgorithmErrc err); */ std::error_code make_error_code(DecodeErrc err); +/** + */ +std::error_code make_error_code(VerificationErrc err); + } // END namespace jwt @@ -55,6 +70,9 @@ namespace std template <> struct is_error_code_enum: true_type {}; + + template <> + struct is_error_code_enum: true_type {}; } #include "jwt/impl/error_codes.ipp" diff --git a/include/jwt/impl/error_codes.ipp b/include/jwt/impl/error_codes.ipp index fcf9329..672967c 100644 --- a/include/jwt/impl/error_codes.ipp +++ b/include/jwt/impl/error_codes.ipp @@ -59,11 +59,42 @@ struct DecodeErrorCategory: std::error_category } }; +/** + */ +struct VerificationErrorCategory: std::error_category +{ + const char* name() const noexcept override + { + return "verification"; + } + + std::string message(int ev) const override + { + switch (static_cast(ev)) + { + case VerificationErrc::InvalidAlgorithm: + return "invalid algorithm"; + case VerificationErrc::TokenExpired: + return "token expired"; + case VerificationErrc::InvalidIssuer: + return "invalid issuer"; + case VerificationErrc::InvalidAudience: + return "invalid audience"; + case VerificationErrc::ImmatureSignature: + return "immature signature"; + }; + + assert (0 && "Code not reached"); + } +}; + // Create global object for the error categories const AlgorithmErrCategory theAlgorithmErrCategory {}; const DecodeErrorCategory theDecodeErrorCategory {}; +const VerificationErrorCategory theVerificationErrorCategory {}; + } @@ -78,6 +109,10 @@ std::error_code make_error_code(DecodeErrc err) return { static_cast(err), theDecodeErrorCategory }; } +std::error_code make_error_code(VerificationErrc err) +{ + return { static_cast(err), theVerificationErrorCategory }; +} } // END namespace jwt diff --git a/include/jwt/impl/jwt.ipp b/include/jwt/impl/jwt.ipp index 65e770f..a5f31c8 100644 --- a/include/jwt/impl/jwt.ipp +++ b/include/jwt/impl/jwt.ipp @@ -330,6 +330,71 @@ std::string jwt_object::signature() const return jws.encode(header_, payload_); } +template +std::error_code jwt_object::verify( + const Params& dparams, + const params::detail::algorithms_param& algos) const +{ + std::error_code ec{}; + + //Verify if the algorithm set in the header + //is any of the one expected by the client. + auto fitr = std::find_if(algos.get().begin(), + algos.get().end(), + [&](const auto& elem) + { + return jwt::str_to_alg(elem) == header().algo(); + }); + + if (fitr == algos.get().end()) { + ec = VerificationErrc::InvalidAlgorithm; + return ec; + } + + //Check for the expiry timings + if (has_claim(registered_claims::expiration)) { + auto curr_time = + std::chrono::duration_cast< + std::chrono::seconds>(std::chrono::system_clock::now().time_since_epoch()).count(); + + auto p_exp = payload() + .get_claim_value(registered_claims::expiration); + + if (p_exp < (curr_time + dparams.leeway)) { + ec = VerificationErrc::TokenExpired; + return ec; + } + } + + //Check for issuer + if (dparams.has_issuer && + has_claim(registered_claims::issuer)) + { + jwt::string_view p_issuer = payload() + .get_claim_value(registered_claims::issuer); + + if (p_issuer.data() != dparams.issuer) { + ec = VerificationErrc::InvalidIssuer; + return ec; + } + } + + //Check for audience + if (dparams.has_aud && + has_claim(registered_claims::audience)) + { + jwt::string_view p_aud = payload() + .get_claim_value(registered_claims::audience); + + if (p_aud.data() != dparams.aud) { + ec = VerificationErrc::InvalidAudience; + return ec; + } + } + + return ec; +} + std::array jwt_object::three_parts(const string_view enc_str) @@ -349,7 +414,7 @@ jwt_object::three_parts(const string_view enc_str) result[1] = string_view{&enc_str[fpos + 1], spos - fpos - 1}; if (spos != enc_str.length()) { - result[2] = string_view{&enc_str[spos + 1], enc_str.length() - spos}; + result[2] = string_view{&enc_str[spos + 1], enc_str.length() - spos - 1}; } return result; @@ -376,6 +441,23 @@ void set_decode_params(DecodeParams& dparams, params::detail::verify_param v, Re return; } +template +void set_decode_params(DecodeParams& dparams, params::detail::issuer_param i, Rest&&... args) +{ + dparams.issuer = std::move(i).get(); + dparams.has_issuer = true; + set_decode_params(dparams, std::forward(args)...); + return; +} + +template +void set_decode_params(DecodeParams& dparams, params::detail::audience_param a, Rest&&... args) +{ + dparams.aud = std::move(a).get(); + dparams.has_aud = true; + set_decode_params(dparams, std::forward(args)...); +} + template void set_decode_params(DecodeParams& dparams) { @@ -400,6 +482,14 @@ jwt_object decode(const string_view enc_str, bool verify = true; /// Leeway parameter. Defaulted to zero seconds. uint32_t leeway = 0; + ///The issuer + //TODO: optional type + bool has_issuer = false; + std::string issuer; + ///The audience + //TODO: optional type + bool has_aud = false; + std::string aud; }; decode_params dparams{}; @@ -414,31 +504,8 @@ jwt_object decode(const string_view enc_str, //throws decode error obj.payload(jwt_payload{parts[1]}); - //TODO: Should be part of jwt_object::verify if (dparams.verify) { - //Verify if the algorithm set in the header - //is any of the one expected by the client. - auto fitr = std::find_if(algos.get().begin(), algos.get().end(), - [&](const auto& elem) { - return jwt::str_to_alg(elem) == obj.header().algo(); - }); - - if (fitr == algos.get().end()) { - throw VerificationError("Provided algorithms do not match with header"); - } - - //Check for the expiry timings - if (obj.payload().has_claim("exp")) { - auto curr_time = std::chrono::duration_cast< - std::chrono::seconds>(std::chrono::system_clock::now().time_since_epoch()).count(); - auto p_exp = obj.payload() - .get_claim_value(registered_claims::expiration) - .get(); - - if (p_exp < (curr_time + dparams.leeway)) { - throw VerificationError("Token expired"); - } - } + std::error_code ec = obj.verify(dparams, algos); } jwt_signature jsign{key}; diff --git a/include/jwt/jwt.hpp b/include/jwt/jwt.hpp index 6ed9735..2a9ba86 100644 --- a/include/jwt/jwt.hpp +++ b/include/jwt/jwt.hpp @@ -358,16 +358,18 @@ public: // Exposed APIs /** */ + template decltype(auto) get_claim_value(const string_view cname) const { - return payload_[cname.data()]; + return payload_[cname.data()].get(); } /** */ + template decltype(auto) get_claim_value(enum registered_claims cname) const { - return get_claim_value(reg_claims_to_str(cname)); + return get_claim_value(reg_claims_to_str(cname)); } /** @@ -648,10 +650,31 @@ public: // Exposed APIs return remove_claim(reg_claims_to_str(cname)); } + /** + */ + bool has_claim(const string_view cname) const noexcept + { + return payload().has_claim(cname); + } + + /** + */ + bool has_claim(enum registered_claims cname) const noexcept + { + return payload().has_claim(cname); + } + /** */ std::string signature() const; + /** + */ + template + std::error_code verify( + const Params& dparams, + const params::detail::algorithms_param& algos) const; + private: // private APIs /** */ diff --git a/include/jwt/parameters.hpp b/include/jwt/parameters.hpp index 1305214..60ecf52 100644 --- a/include/jwt/parameters.hpp +++ b/include/jwt/parameters.hpp @@ -2,6 +2,7 @@ #define CPP_JWT_PARAMETERS_HPP #include +#include #include #include #include @@ -143,6 +144,34 @@ struct leeway_param uint32_t leeway_; }; +/** + */ +struct audience_param +{ + audience_param(std::string aud) + : aud_(std::move(aud)) + {} + + const std::string& get() const& noexcept { return aud_; } + std::string get() && noexcept { return aud_; } + + std::string aud_; +}; + +/** + */ +struct issuer_param +{ + issuer_param(std::string iss) + : iss_(std::move(iss)) + {} + + const std::string& get() const& noexcept { return iss_; } + std::string get() && noexcept { return iss_; } + + std::string iss_; +}; + } // END namespace detail // Useful typedef diff --git a/include/jwt/test/test_jwt_object b/include/jwt/test/test_jwt_object index b5ce73d..ffa938e 100755 Binary files a/include/jwt/test/test_jwt_object and b/include/jwt/test/test_jwt_object differ diff --git a/include/jwt/test/test_jwt_object.cc b/include/jwt/test/test_jwt_object.cc index 69deda1..29f4d85 100644 --- a/include/jwt/test/test_jwt_object.cc +++ b/include/jwt/test/test_jwt_object.cc @@ -79,14 +79,16 @@ MIGkAgEBBDBeLCgapjZmvTatMHaYX3A02+0Ys3Tr8kda+E9DFnmCSiCOEig519fT ; std::cout << "pem sign " << obj.signature() << std::endl; - std::cout << "Get claim value for exp: " << obj.payload().get_claim_value("exp") << std::endl; + std::cout << "Get claim value for exp: " << + obj.payload().get_claim_value("exp") << std::endl; + sleep(4); auto dec_obj = jwt::decode(obj.signature(), pub_key, algorithms({"es256"})); std::cout << dec_obj.payload() << std::endl; } int main() { - //basic_jwt_object_test(); - jwt_object_pem_test(); + basic_jwt_object_test(); + //jwt_object_pem_test(); return 0; }