#ifndef JWT_IPP #define JWT_IPP #include "jwt/detail/meta.hpp" namespace jwt { template std::string to_json_str(const T& obj, bool pretty) { return pretty ? obj.create_json_obj().dump(2) : obj.create_json_obj().dump() ; } template std::ostream& write(std::ostream& os, const T& obj, bool pretty) { pretty ? (os << std::setw(2) << obj.create_json_obj()) : (os << obj.create_json_obj()) ; return os; } template {}>::type> std::ostream& operator<< (std::ostream& os, const T& obj) { os << obj.create_json_obj(); return os; } //======================================================================== void jwt_header::decode(const string_view enc_str, std::error_code& ec) noexcept { ec.clear(); std::string json_str = base64_decode(enc_str); json_t obj; try { obj = json_t::parse(std::move(json_str)); } catch(const std::exception& e) { ec = DecodeErrc::JsonParseError; return; } //Look for the algorithm field auto alg_itr = obj.find("alg"); if (alg_itr == obj.end()) { ec = DecodeErrc::AlgHeaderMiss; return; } alg_ = str_to_alg(alg_itr.value().get()); if (alg_ != algorithm::NONE) { auto itr = obj.find("typ"); if (itr == obj.end()) { ec = DecodeErrc::TypHeaderMiss; return; } const auto& typ = itr.value().get(); if (strcasecmp(typ.c_str(), "JWT")) { ec = DecodeErrc::TypMismatch; return; } typ_ = str_to_type(typ); } else { //TODO: } return; } void jwt_header::decode(const string_view enc_str) throw(DecodeError) { std::error_code ec; decode(enc_str, ec); if (ec) { throw DecodeError(ec.message()); } return; } void jwt_payload::decode(const string_view enc_str, std::error_code& ec) noexcept { ec.clear(); std::string json_str = base64_decode(enc_str); try { payload_ = json_t::parse(std::move(json_str)); } catch(const std::exception& e) { ec = DecodeErrc::JsonParseError; return; } return; } void jwt_payload::decode(const string_view enc_str) throw(DecodeError) { std::error_code ec; decode(enc_str, ec); if (ec) { throw DecodeError(ec.message()); } return; } std::string jwt_signature::encode(const jwt_header& header, const jwt_payload& payload) { std::string jwt_msg; //TODO: Optimize allocations sign_func_t sign_fn = get_sign_algorithm_impl(header); std::string hdr_sign = header.base64_encode(); std::string pld_sign = payload.base64_encode(); std::string data = hdr_sign + '.' + pld_sign; auto res = sign_fn(key_, data); if (res.second) { std::cout << res.second.message() << std::endl; return {}; } std::string b64hash = base64_encode(res.first.c_str(), res.first.length()); auto new_len = base64_uri_encode(&b64hash[0], b64hash.length()); b64hash.resize(new_len); jwt_msg = data + '.' + b64hash; return jwt_msg; } bool jwt_signature::verify(const jwt_header& header, const string_view hdr_pld_sign, const string_view jwt_sign) { //TODO: is bool the right choice ? verify_func_t verify_fn = get_verify_algorithm_impl(header); verify_fn(key_, hdr_pld_sign, jwt_sign); return true; } sign_func_t jwt_signature::get_sign_algorithm_impl(const jwt_header& hdr) const noexcept { sign_func_t ret = nullptr; switch (hdr.algo()) { case algorithm::HS256: ret = HMACSign::sign; break; case algorithm::HS384: ret = HMACSign::sign; break; case algorithm::HS512: ret = HMACSign::sign; break; case algorithm::NONE: ret = HMACSign::sign; break; case algorithm::RS256: ret = PEMSign::sign; break; case algorithm::RS384: ret = PEMSign::sign; break; case algorithm::RS512: ret = PEMSign::sign; break; case algorithm::ES256: ret = PEMSign::sign; break; case algorithm::ES384: ret = PEMSign::sign; break; case algorithm::ES512: ret = PEMSign::sign; break; default: assert (0 && "Code not reached"); }; return ret; } verify_func_t jwt_signature::get_verify_algorithm_impl(const jwt_header& hdr) const noexcept { verify_func_t ret = nullptr; switch (hdr.algo()) { case algorithm::HS256: ret = HMACSign::verify; break; case algorithm::HS384: ret = HMACSign::verify; break; case algorithm::HS512: ret = HMACSign::verify; break; case algorithm::NONE: ret = HMACSign::verify; break; case algorithm::RS256: ret = PEMSign::verify; break; case algorithm::RS384: ret = PEMSign::verify; break; case algorithm::RS512: ret = PEMSign::verify; break; case algorithm::ES256: ret = PEMSign::verify; break; case algorithm::ES384: ret = PEMSign::verify; break; case algorithm::ES512: ret = PEMSign::verify; break; default: assert (0 && "Code not reached"); }; return ret; } // template jwt_object::jwt_object(Args&&... args) { static_assert (detail::meta::are_all_params::value, "All constructor argument types must model ParameterConcept"); set_parameters(std::forward(args)...); } template void jwt_object::set_parameters( params::detail::payload_param&& payload, Rest&&... rargs) { for (const auto& elem : payload.get()) { payload_.add_claim(std::move(elem.first), std::move(elem.second)); } set_parameters(std::forward(rargs)...); } template void jwt_object::set_parameters( params::detail::secret_param secret, Rest&&... rargs) { secret_.assign(secret.get().data(), secret.get().length()); set_parameters(std::forward(rargs)...); } template void jwt_object::set_parameters( params::detail::algorithm_param alg, Rest&&... rargs) { header_.algo(alg.get()); set_parameters(std::forward(rargs)...); } template void jwt_object::set_parameters( params::detail::headers_param&& header, Rest&&... rargs) { //TODO: add kid support set_parameters(std::forward(rargs)...); } void jwt_object::set_parameters() { //setinel call return; } template >::value> > jwt_object& jwt_object::add_claim(const string_view name, T&& value) { payload_.add_claim(name, std::forward(value)); return *this; } jwt_object& jwt_object::add_claim(const string_view name, system_time_t tp) { return add_claim( name, std::chrono::duration_cast< std::chrono::seconds>(tp.time_since_epoch()).count() ); } jwt_object& jwt_object::remove_claim(const string_view name) { payload_.remove_claim(name); return *this; } std::string jwt_object::signature() const { jwt_signature jws{secret_}; return jws.encode(header_, payload_); } std::array jwt_object::three_parts(const string_view enc_str) { std::array result; size_t fpos = enc_str.find_first_of('.'); assert (fpos != string_view::npos); result[0] = string_view{&enc_str[0], fpos}; size_t spos = enc_str.find_first_of('.', fpos + 1); if (spos == string_view::npos) { //TODO: Check for none algorithm } result[1] = string_view{&enc_str[fpos + 1], spos - fpos - 1}; if (spos != enc_str.length()) { result[2] = string_view{&enc_str[spos + 1], enc_str.length() - spos}; } return result; } //==================================================================== jwt_object jwt_decode(const string_view encoded_str, const string_view key, bool validate) { //TODO: implement error_code jwt_object jobj; auto parts = jwt_object::three_parts(encoded_str); //throws decode error jobj.header(jwt_header{parts[0]}); //throws decode error jobj.payload(jwt_payload{parts[1]}); jwt_signature jsign{key}; //length of the encoded header and payload only. //Addition of '1' to account for the '.' character. auto l = parts[0].length() + 1 + parts[1].length(); auto res = jsign.verify(jobj.header(), encoded_str.substr(0, l), parts[2]); return jobj; } } // END namespace jwt #endif