#ifndef JWT_IPP #define JWT_IPP #include "jwt/detail/meta.hpp" #include namespace jwt { /** */ static inline void jwt_throw_exception(const std::error_code& ec); template std::string to_json_str(const T& obj, bool pretty) { return pretty ? obj.create_json_obj().dump(2) : obj.create_json_obj().dump() ; } template std::ostream& write(std::ostream& os, const T& obj, bool pretty) { pretty ? (os << std::setw(2) << obj.create_json_obj()) : (os << obj.create_json_obj()) ; return os; } template {}>::type> std::ostream& operator<< (std::ostream& os, const T& obj) { os << obj.create_json_obj(); return os; } //======================================================================== void jwt_header::decode(const string_view enc_str, std::error_code& ec) noexcept { ec.clear(); std::string json_str = base64_decode(enc_str); json_t obj; try { obj = json_t::parse(std::move(json_str)); } catch(const std::exception& e) { ec = DecodeErrc::JsonParseError; return; } //Look for the algorithm field auto alg_itr = obj.find("alg"); if (alg_itr == obj.end()) { ec = DecodeErrc::AlgHeaderMiss; return; } alg_ = str_to_alg(alg_itr.value().get()); if (alg_ != algorithm::NONE) { auto itr = obj.find("typ"); if (itr == obj.end()) { ec = DecodeErrc::TypHeaderMiss; return; } const auto& typ = itr.value().get(); if (strcasecmp(typ.c_str(), "JWT")) { ec = DecodeErrc::TypMismatch; return; } typ_ = str_to_type(typ); } else { //TODO: } return; } void jwt_header::decode(const string_view enc_str) throw(DecodeError) { std::error_code ec; decode(enc_str, ec); if (ec) { throw DecodeError(ec.message()); } return; } void jwt_payload::decode(const string_view enc_str, std::error_code& ec) noexcept { ec.clear(); std::string json_str = base64_decode(enc_str); try { payload_ = json_t::parse(std::move(json_str)); } catch(const std::exception& e) { ec = DecodeErrc::JsonParseError; return; } //populate the claims set for (auto it = payload_.begin(); it != payload_.end(); ++it) { auto ret = claim_names_.insert(it.key()); if (!ret.second) { ec = DecodeErrc::DuplClaims; break; } } return; } void jwt_payload::decode(const string_view enc_str) throw(DecodeError) { std::error_code ec; decode(enc_str, ec); if (ec) { throw DecodeError(ec.message()); } return; } std::string jwt_signature::encode(const jwt_header& header, const jwt_payload& payload, std::error_code& ec) { std::string jwt_msg; ec.clear(); //TODO: Optimize allocations sign_func_t sign_fn = get_sign_algorithm_impl(header); std::string hdr_sign = header.base64_encode(); std::string pld_sign = payload.base64_encode(); std::string data = hdr_sign + '.' + pld_sign; auto res = sign_fn(key_, data); if (res.second && res.second != AlgorithmErrc::NoneAlgorithmUsed) { ec = res.second; return {}; } std::string b64hash; if (!res.second) { b64hash = base64_encode(res.first.c_str(), res.first.length()); } auto new_len = base64_uri_encode(&b64hash[0], b64hash.length()); b64hash.resize(new_len); jwt_msg = data + '.' + b64hash; return jwt_msg; } verify_result_t jwt_signature::verify(const jwt_header& header, const string_view hdr_pld_sign, const string_view jwt_sign) { verify_func_t verify_fn = get_verify_algorithm_impl(header); return verify_fn(key_, hdr_pld_sign, jwt_sign); } sign_func_t jwt_signature::get_sign_algorithm_impl(const jwt_header& hdr) const noexcept { sign_func_t ret = nullptr; switch (hdr.algo()) { case algorithm::HS256: ret = HMACSign::sign; break; case algorithm::HS384: ret = HMACSign::sign; break; case algorithm::HS512: ret = HMACSign::sign; break; case algorithm::NONE: ret = HMACSign::sign; break; case algorithm::RS256: ret = PEMSign::sign; break; case algorithm::RS384: ret = PEMSign::sign; break; case algorithm::RS512: ret = PEMSign::sign; break; case algorithm::ES256: ret = PEMSign::sign; break; case algorithm::ES384: ret = PEMSign::sign; break; case algorithm::ES512: ret = PEMSign::sign; break; default: assert (0 && "Code not reached"); }; return ret; } verify_func_t jwt_signature::get_verify_algorithm_impl(const jwt_header& hdr) const noexcept { verify_func_t ret = nullptr; switch (hdr.algo()) { case algorithm::HS256: ret = HMACSign::verify; break; case algorithm::HS384: ret = HMACSign::verify; break; case algorithm::HS512: ret = HMACSign::verify; break; case algorithm::NONE: ret = HMACSign::verify; break; case algorithm::RS256: ret = PEMSign::verify; break; case algorithm::RS384: ret = PEMSign::verify; break; case algorithm::RS512: ret = PEMSign::verify; break; case algorithm::ES256: ret = PEMSign::verify; break; case algorithm::ES384: ret = PEMSign::verify; break; case algorithm::ES512: ret = PEMSign::verify; break; default: assert (0 && "Code not reached"); }; return ret; } // template jwt_object::jwt_object(Args&&... args) { static_assert (detail::meta::are_all_params::value, "All constructor argument types must model ParameterConcept"); set_parameters(std::forward(args)...); } template void jwt_object::set_parameters( params::detail::payload_param&& payload, Rest&&... rargs) { for (const auto& elem : payload.get()) { payload_.add_claim(std::move(elem.first), std::move(elem.second)); } set_parameters(std::forward(rargs)...); } template void jwt_object::set_parameters( params::detail::secret_param secret, Rest&&... rargs) { secret_.assign(secret.get().data(), secret.get().length()); set_parameters(std::forward(rargs)...); } template void jwt_object::set_parameters( params::detail::algorithm_param alg, Rest&&... rargs) { header_.algo(alg.get()); set_parameters(std::forward(rargs)...); } template void jwt_object::set_parameters( params::detail::headers_param&& header, Rest&&... rargs) { //TODO: add kid support set_parameters(std::forward(rargs)...); } void jwt_object::set_parameters() { //sentinel call return; } template >::value> > jwt_object& jwt_object::add_claim(const string_view name, T&& value) { payload_.add_claim(name, std::forward(value)); return *this; } jwt_object& jwt_object::add_claim(const string_view name, system_time_t tp) { return add_claim( name, std::chrono::duration_cast< std::chrono::seconds>(tp.time_since_epoch()).count() ); } jwt_object& jwt_object::remove_claim(const string_view name) { payload_.remove_claim(name); return *this; } std::string jwt_object::signature(std::error_code& ec) const { ec.clear(); //key/secret should be set for any algorithm except NONE if (header().algo() != jwt::algorithm::NONE) { if (secret_.length() == 0) { ec = AlgorithmErrc::KeyNotFoundErr; return {}; } } jwt_signature jws{secret_}; return jws.encode(header_, payload_, ec); } std::string jwt_object::signature() const { std::error_code ec; std::string res = signature(ec); if (ec) { throw SigningError(ec.message()); } return res; } template std::error_code jwt_object::verify( const Params& dparams, const params::detail::algorithms_param& algos) const { std::error_code ec{}; //Verify if the algorithm set in the header //is any of the one expected by the client. auto fitr = std::find_if(algos.get().begin(), algos.get().end(), [&](const auto& elem) { return jwt::str_to_alg(elem) == header().algo(); }); if (fitr == algos.get().end()) { ec = VerificationErrc::InvalidAlgorithm; return ec; } //Check for the expiry timings if (has_claim(registered_claims::expiration)) { auto curr_time = std::chrono::duration_cast< std::chrono::seconds>(std::chrono::system_clock::now().time_since_epoch()).count(); auto p_exp = payload() .get_claim_value(registered_claims::expiration); if (p_exp < (curr_time + dparams.leeway)) { ec = VerificationErrc::TokenExpired; return ec; } } //Check for issuer if (dparams.has_issuer && has_claim(registered_claims::issuer)) { jwt::string_view p_issuer = payload() .get_claim_value(registered_claims::issuer); if (p_issuer.data() != dparams.issuer) { ec = VerificationErrc::InvalidIssuer; return ec; } } //Check for audience if (dparams.has_aud && has_claim(registered_claims::audience)) { jwt::string_view p_aud = payload() .get_claim_value(registered_claims::audience); if (p_aud.data() != dparams.aud) { ec = VerificationErrc::InvalidAudience; return ec; } } //Check for NBF if (has_claim(registered_claims::not_before)) { auto curr_time = std::chrono::duration_cast< std::chrono::seconds>(std::chrono::system_clock::now().time_since_epoch()).count(); auto p_exp = payload() .get_claim_value(registered_claims::not_before); if ((p_exp - dparams.leeway) < curr_time) { ec = VerificationErrc::ImmatureSignature; return ec; } } return ec; } std::array jwt_object::three_parts(const string_view enc_str) { std::array result; size_t fpos = enc_str.find_first_of('.'); assert (fpos != string_view::npos); result[0] = string_view{&enc_str[0], fpos}; size_t spos = enc_str.find_first_of('.', fpos + 1); result[1] = string_view{&enc_str[fpos + 1], spos - fpos - 1}; if (spos != enc_str.length()) { result[2] = string_view{&enc_str[spos + 1], enc_str.length() - spos - 1}; } return result; } template void jwt_object::set_decode_params(DecodeParams& dparams, params::detail::leeway_param l, Rest&&... args) { dparams.leeway = l.get(); jwt_object::set_decode_params(dparams, std::forward(args)...); return; } template void jwt_object::set_decode_params(DecodeParams& dparams, params::detail::verify_param v, Rest&&... args) { dparams.verify = v.get(); jwt_object::set_decode_params(dparams, std::forward(args)...); return; } template void jwt_object::set_decode_params(DecodeParams& dparams, params::detail::issuer_param i, Rest&&... args) { dparams.issuer = std::move(i).get(); dparams.has_issuer = true; jwt_object::set_decode_params(dparams, std::forward(args)...); return; } template void jwt_object::set_decode_params(DecodeParams& dparams, params::detail::audience_param a, Rest&&... args) { dparams.aud = std::move(a).get(); dparams.has_aud = true; jwt_object::set_decode_params(dparams, std::forward(args)...); } template void jwt_object::set_decode_params(DecodeParams& dparams) { return; } //================================================================== template jwt_object decode(const string_view enc_str, const string_view key, const params::detail::algorithms_param& algos, std::error_code& ec, Args&&... args) { ec.clear(); jwt_object obj; if (algos.get().size() == 0) { ec = DecodeErrc::EmptyAlgoList; return obj; } struct decode_params { /// Verify parameter. Defaulted to true. bool verify = true; /// Leeway parameter. Defaulted to zero seconds. uint32_t leeway = 0; ///The issuer //TODO: optional type bool has_issuer = false; std::string issuer; ///The audience //TODO: optional type bool has_aud = false; std::string aud; }; decode_params dparams{}; jwt_object::set_decode_params(dparams, std::forward(args)...); //Signature must have atleast 2 dots auto dot_cnt = std::count_if(std::begin(enc_str), std::end(enc_str), [](char ch) { return ch == '.'; }); if (dot_cnt < 2) { ec = DecodeErrc::SignatureFormatError; return obj; } auto parts = jwt_object::three_parts(enc_str); //throws decode error jwt_header hdr{}; hdr.decode(parts[0], ec); if (ec) { return obj; } //obj.header(jwt_header{parts[0]}); obj.header(std::move(hdr)); //If the algorithm is not NONE, it must not //have more than two dots ('.') and the split //must result in three strings with some length. if (obj.header().algo() != jwt::algorithm::NONE) { if (dot_cnt > 2) { ec = DecodeErrc::SignatureFormatError; return obj; } if (parts[2].length() == 0) { ec = DecodeErrc::SignatureFormatError; return obj; } } //throws decode error jwt_payload payload{}; payload.decode(parts[1], ec); if (ec) { return obj; } obj.payload(std::move(payload)); if (dparams.verify) { ec = obj.verify(dparams, algos); if (ec) return obj; } //Verify the signature only if some algorithm was used if (obj.header().algo() != algorithm::NONE) { jwt_signature jsign{key}; // Length of the encoded header and payload only. // Addition of '1' to account for the '.' character. auto l = parts[0].length() + 1 + parts[1].length(); //MemoryAllocationError is not caught verify_result_t res = jsign.verify(obj.header(), enc_str.substr(0, l), parts[2]); if (res.second) { ec = res.second; return obj; } if (!res.first) { ec = VerificationErrc::InvalidSignature; return obj; } } else { ec = AlgorithmErrc::NoneAlgorithmUsed; } return obj; } template jwt_object decode(const string_view enc_str, const string_view key, const params::detail::algorithms_param& algos, Args&&... args) { std::error_code ec{}; auto jwt_obj = decode(enc_str, key, algos, ec, std::forward(args)...); if (ec) { jwt_throw_exception(ec); } return jwt_obj; } void jwt_throw_exception(const std::error_code& ec) { const auto& cat = ec.category(); if (&cat == &theVerificationErrorCategory) { switch (static_cast(ec.value())) { case VerificationErrc::InvalidAlgorithm: { throw InvalidAlgorithmError(ec.message()); } case VerificationErrc::TokenExpired: { throw TokenExpiredError(ec.message()); } case VerificationErrc::InvalidIssuer: { throw InvalidIssuerError(ec.message()); } case VerificationErrc::InvalidAudience: { throw InvalidAudienceError(ec.message()); } case VerificationErrc::ImmatureSignature: { throw ImmatureSignatureError(ec.message()); } case VerificationErrc::InvalidSignature: { throw InvalidSignatureError(ec.message()); } default: assert (0 && "Unknown error code"); }; } if (&cat == &theDecodeErrorCategory) { switch (static_cast(ec.value())) { case DecodeErrc::SignatureFormatError: { throw SignatureFormatError(ec.message()); } default: { throw DecodeError(ec.message()); } }; assert (0 && "Unknown error code"); } if (&cat == &theAlgorithmErrCategory) { switch (static_cast(ec.value())) { case AlgorithmErrc::VerificationErr: { throw InvalidSignatureError(ec.message()); } case AlgorithmErrc::NoneAlgorithmUsed: { //Not an error actually. break; } default: assert (0 && "Unknown error code or not to be treated as an error"); }; } return; } } // END namespace jwt #endif