add basic discord support

This commit is contained in:
2024-11-09 01:58:06 +03:00
parent c3e7a9c92d
commit 8965b7ee90
869 changed files with 191278 additions and 7 deletions

View File

@@ -0,0 +1,57 @@
set(CURRENT_LIB_NAME hpke)
###
### Library Config
###
file(GLOB_RECURSE LIB_HEADERS CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/include/*.h")
file(GLOB_RECURSE LIB_SOURCES CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp")
# -Werror=dangling-reference
add_library(${CURRENT_LIB_NAME} STATIC ${LIB_HEADERS} ${LIB_SOURCES})
add_dependencies(${CURRENT_LIB_NAME} bytes tls_syntax)
target_include_directories(${CURRENT_LIB_NAME}
PRIVATE
"${JSON_INCLUDE_INTERFACE}")
target_link_libraries(${CURRENT_LIB_NAME}
PUBLIC
bytes tls_syntax
)
target_include_directories(${CURRENT_LIB_NAME}
PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/include>
$<INSTALL_INTERFACE:include/${PROJECT_NAME}>
PRIVATE
${OPENSSL_INCLUDE_DIR}
)
# Private statically linked dependencies
target_link_libraries(${CURRENT_LIB_NAME} PRIVATE
mlspp
mls_vectors
bytes
tls_syntax
)
target_compile_options(
"${CURRENT_LIB_NAME}"
PUBLIC
"$<$<PLATFORM_ID:Windows>:/bigobj;/Zc:preprocessor>"
PRIVATE
"$<$<PLATFORM_ID:Windows>:$<$<CONFIG:Debug>:/sdl;/Od;/DEBUG;/MP;/DFD_SETSIZE=1024>>"
"$<$<PLATFORM_ID:Windows>:$<$<CONFIG:Release>:/O2;/Oi;/Oy;/GL;/Gy;/sdl;/MP;/DFD_SETSIZE=1024>>"
"$<$<PLATFORM_ID:Linux>:$<$<CONFIG:Debug>:-Wall;-Wempty-body;-Wno-psabi;-Wunknown-pragmas;-Wignored-qualifiers;-Wimplicit-fallthrough;-Wmissing-field-initializers;-Wsign-compare;-Wtype-limits;-Wuninitialized;-Wshift-negative-value;-pthread;-g;-Og;-fPIC>>"
"$<$<PLATFORM_ID:Linux>:$<$<CONFIG:Release>:-Wall;-Wempty-body;-Wno-psabi;-Wunknown-pragmas;-Wignored-qualifiers;-Wimplicit-fallthrough;-Wmissing-field-initializers;-Wsign-compare;-Wtype-limits;-Wuninitialized;-Wshift-negative-value;-pthread;-O3;-fPIC>>"
"${AVX_FLAG}"
)
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/bytes/include")
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/hpke/include")
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/mls_vectors/include")
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/tls_syntax/include")
# For nlohmann/json.hpp
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../../include")

View File

@@ -0,0 +1,20 @@
#pragma once
#include <bytes/bytes.h>
using namespace mlspp::bytes_ns;
namespace mlspp::hpke {
std::string
to_base64(const bytes& data);
std::string
to_base64url(const bytes& data);
bytes
from_base64(const std::string& enc);
bytes
from_base64url(const std::string& enc);
} // namespace mlspp::hpke

View File

@@ -0,0 +1,75 @@
#pragma once
#include <memory>
#include <optional>
#include <bytes/bytes.h>
#include <chrono>
#include <hpke/signature.h>
#include <map>
using namespace mlspp::bytes_ns;
namespace mlspp::hpke {
struct Certificate
{
private:
struct ParsedCertificate;
std::unique_ptr<ParsedCertificate> parsed_cert;
public:
struct NameType
{
static const int organization;
static const int common_name;
static const int organizational_unit;
static const int country;
static const int serial_number;
static const int state_or_province_name;
};
using ParsedName = std::map<int, std::string>;
// Certificate Expiration Status
enum struct ExpirationStatus
{
inactive, // now < notBefore
active, // notBefore < now < notAfter
expired, // notAfter < now
};
explicit Certificate(const bytes& der);
explicit Certificate(std::unique_ptr<ParsedCertificate>&& parsed_cert_in);
Certificate() = delete;
Certificate(const Certificate& other);
~Certificate();
static std::vector<Certificate> parse_pem(const bytes& pem);
bool valid_from(const Certificate& parent) const;
// Accessors for parsed certificate elements
uint64_t issuer_hash() const;
uint64_t subject_hash() const;
ParsedName issuer() const;
ParsedName subject() const;
bool is_ca() const;
ExpirationStatus expiration_status() const;
std::optional<bytes> subject_key_id() const;
std::optional<bytes> authority_key_id() const;
std::vector<std::string> email_addresses() const;
std::vector<std::string> dns_names() const;
bytes hash() const;
std::chrono::system_clock::time_point not_before() const;
std::chrono::system_clock::time_point not_after() const;
Signature::ID public_key_algorithm() const;
Signature::ID signature_algorithm() const;
const std::unique_ptr<Signature::PublicKey> public_key;
const bytes raw;
};
bool
operator==(const Certificate& lhs, const Certificate& rhs);
} // namespace mlspp::hpke

View File

@@ -0,0 +1,37 @@
#pragma once
#include <memory>
#include <bytes/bytes.h>
using namespace mlspp::bytes_ns;
namespace mlspp::hpke {
struct Digest
{
enum struct ID
{
SHA256,
SHA384,
SHA512,
};
template<ID id>
static const Digest& get();
const ID id;
bytes hash(const bytes& data) const;
bytes hmac(const bytes& key, const bytes& data) const;
const size_t hash_size;
private:
explicit Digest(ID id);
bytes hmac_for_hkdf_extract(const bytes& key, const bytes& data) const;
friend struct HKDF;
};
} // namespace mlspp::hpke

View File

@@ -0,0 +1,253 @@
#pragma once
#include <memory>
#include <optional>
#include <bytes/bytes.h>
using namespace mlspp::bytes_ns;
namespace mlspp::hpke {
struct KEM
{
enum struct ID : uint16_t
{
DHKEM_P256_SHA256 = 0x0010,
DHKEM_P384_SHA384 = 0x0011,
DHKEM_P521_SHA512 = 0x0012,
DHKEM_X25519_SHA256 = 0x0020,
#if !defined(WITH_BORINGSSL)
DHKEM_X448_SHA512 = 0x0021,
#endif
};
template<KEM::ID>
static const KEM& get();
virtual ~KEM() = default;
struct PublicKey
{
virtual ~PublicKey() = default;
};
struct PrivateKey
{
virtual ~PrivateKey() = default;
virtual std::unique_ptr<PublicKey> public_key() const = 0;
};
const ID id;
const size_t secret_size;
const size_t enc_size;
const size_t pk_size;
const size_t sk_size;
virtual std::unique_ptr<PrivateKey> generate_key_pair() const = 0;
virtual std::unique_ptr<PrivateKey> derive_key_pair(
const bytes& ikm) const = 0;
virtual bytes serialize(const PublicKey& pk) const = 0;
virtual std::unique_ptr<PublicKey> deserialize(const bytes& enc) const = 0;
virtual bytes serialize_private(const PrivateKey& sk) const;
virtual std::unique_ptr<PrivateKey> deserialize_private(
const bytes& skm) const;
// (shared_secret, enc)
virtual std::pair<bytes, bytes> encap(const PublicKey& pkR) const = 0;
virtual bytes decap(const bytes& enc, const PrivateKey& skR) const = 0;
// (shared_secret, enc)
virtual std::pair<bytes, bytes> auth_encap(const PublicKey& pkR,
const PrivateKey& skS) const;
virtual bytes auth_decap(const bytes& enc,
const PublicKey& pkS,
const PrivateKey& skR) const;
protected:
KEM(ID id_in,
size_t secret_size_in,
size_t enc_size_in,
size_t pk_size_in,
size_t sk_size_in);
};
struct KDF
{
enum struct ID : uint16_t
{
HKDF_SHA256 = 0x0001,
HKDF_SHA384 = 0x0002,
HKDF_SHA512 = 0x0003,
};
template<KDF::ID id>
static const KDF& get();
virtual ~KDF() = default;
const ID id;
const size_t hash_size;
virtual bytes extract(const bytes& salt, const bytes& ikm) const = 0;
virtual bytes expand(const bytes& prk,
const bytes& info,
size_t size) const = 0;
bytes labeled_extract(const bytes& suite_id,
const bytes& salt,
const bytes& label,
const bytes& ikm) const;
bytes labeled_expand(const bytes& suite_id,
const bytes& prk,
const bytes& label,
const bytes& info,
size_t size) const;
protected:
KDF(ID id_in, size_t hash_size_in);
};
struct AEAD
{
enum struct ID : uint16_t
{
AES_128_GCM = 0x0001,
AES_256_GCM = 0x0002,
CHACHA20_POLY1305 = 0x0003,
// Reserved identifier for pseudo-AEAD on contexts that only allow export
export_only = 0xffff,
};
template<AEAD::ID id>
static const AEAD& get();
virtual ~AEAD() = default;
const ID id;
const size_t key_size;
const size_t nonce_size;
virtual bytes seal(const bytes& key,
const bytes& nonce,
const bytes& aad,
const bytes& pt) const = 0;
virtual std::optional<bytes> open(const bytes& key,
const bytes& nonce,
const bytes& aad,
const bytes& ct) const = 0;
protected:
AEAD(ID id_in, size_t key_size_in, size_t nonce_size_in);
};
struct Context
{
bytes do_export(const bytes& exporter_context, size_t size) const;
protected:
bytes suite;
bytes key;
bytes nonce;
bytes exporter_secret;
const KDF& kdf;
const AEAD& aead;
bytes current_nonce() const;
void increment_seq();
private:
uint64_t seq;
Context(bytes suite_in,
bytes key_in,
bytes nonce_in,
bytes exporter_secret_in,
const KDF& kdf_in,
const AEAD& aead_in);
friend struct HPKE;
friend struct HPKETest;
friend bool operator==(const Context& lhs, const Context& rhs);
};
struct SenderContext : public Context
{
SenderContext(Context&& c);
bytes seal(const bytes& aad, const bytes& pt);
};
struct ReceiverContext : public Context
{
ReceiverContext(Context&& c);
std::optional<bytes> open(const bytes& aad, const bytes& ct);
};
struct HPKE
{
enum struct Mode : uint8_t
{
base = 0,
psk = 1,
auth = 2,
auth_psk = 3,
};
HPKE(KEM::ID kem_id, KDF::ID kdf_id, AEAD::ID aead_id);
using SenderInfo = std::pair<bytes, SenderContext>;
SenderInfo setup_base_s(const KEM::PublicKey& pkR, const bytes& info) const;
ReceiverContext setup_base_r(const bytes& enc,
const KEM::PrivateKey& skR,
const bytes& info) const;
SenderInfo setup_psk_s(const KEM::PublicKey& pkR,
const bytes& info,
const bytes& psk,
const bytes& psk_id) const;
ReceiverContext setup_psk_r(const bytes& enc,
const KEM::PrivateKey& skR,
const bytes& info,
const bytes& psk,
const bytes& psk_id) const;
SenderInfo setup_auth_s(const KEM::PublicKey& pkR,
const bytes& info,
const KEM::PrivateKey& skS) const;
ReceiverContext setup_auth_r(const bytes& enc,
const KEM::PrivateKey& skR,
const bytes& info,
const KEM::PublicKey& pkS) const;
SenderInfo setup_auth_psk_s(const KEM::PublicKey& pkR,
const bytes& info,
const bytes& psk,
const bytes& psk_id,
const KEM::PrivateKey& skS) const;
ReceiverContext setup_auth_psk_r(const bytes& enc,
const KEM::PrivateKey& skR,
const bytes& info,
const bytes& psk,
const bytes& psk_id,
const KEM::PublicKey& pkS) const;
bytes suite;
const KEM& kem;
const KDF& kdf;
const AEAD& aead;
private:
static bool verify_psk_inputs(Mode mode,
const bytes& psk,
const bytes& psk_id);
Context key_schedule(Mode mode,
const bytes& shared_secret,
const bytes& info,
const bytes& psk,
const bytes& psk_id) const;
};
} // namespace mlspp::hpke

View File

@@ -0,0 +1,11 @@
#pragma once
#include <bytes/bytes.h>
using namespace mlspp::bytes_ns;
namespace mlspp::hpke {
bytes
random_bytes(size_t size);
} // namespace mlspp::hpke

View File

@@ -0,0 +1,89 @@
#pragma once
#include <memory>
#include <bytes/bytes.h>
using namespace mlspp::bytes_ns;
namespace mlspp::hpke {
struct Signature
{
enum struct ID
{
P256_SHA256,
P384_SHA384,
P521_SHA512,
Ed25519,
#if !defined(WITH_BORINGSSL)
Ed448,
#endif
RSA_SHA256,
RSA_SHA384,
RSA_SHA512,
};
template<Signature::ID id>
static const Signature& get();
virtual ~Signature() = default;
struct PublicKey
{
virtual ~PublicKey() = default;
};
struct PrivateKey
{
virtual ~PrivateKey() = default;
virtual std::unique_ptr<PublicKey> public_key() const = 0;
};
const ID id;
virtual std::unique_ptr<PrivateKey> generate_key_pair() const = 0;
virtual std::unique_ptr<PrivateKey> derive_key_pair(
const bytes& ikm) const = 0;
virtual bytes serialize(const PublicKey& pk) const = 0;
virtual std::unique_ptr<PublicKey> deserialize(const bytes& enc) const = 0;
virtual bytes serialize_private(const PrivateKey& sk) const = 0;
virtual std::unique_ptr<PrivateKey> deserialize_private(
const bytes& skm) const = 0;
struct PrivateJWK
{
const Signature& sig;
std::optional<std::string> key_id;
std::unique_ptr<PrivateKey> key;
};
static PrivateJWK parse_jwk_private(const std::string& jwk_json);
struct PublicJWK
{
const Signature& sig;
std::optional<std::string> key_id;
std::unique_ptr<PublicKey> key;
};
static PublicJWK parse_jwk(const std::string& jwk_json);
virtual std::unique_ptr<PrivateKey> import_jwk_private(
const std::string& jwk_json) const = 0;
virtual std::unique_ptr<PublicKey> import_jwk(
const std::string& jwk_json) const = 0;
virtual std::string export_jwk_private(const PrivateKey& env) const = 0;
virtual std::string export_jwk(const PublicKey& env) const = 0;
virtual bytes sign(const bytes& data, const PrivateKey& sk) const = 0;
virtual bool verify(const bytes& data,
const bytes& sig,
const PublicKey& pk) const = 0;
static std::unique_ptr<PrivateKey> generate_rsa(size_t bits);
protected:
Signature(ID id_in);
};
} // namespace mlspp::hpke

View File

@@ -0,0 +1,82 @@
#pragma once
#include <memory>
#include <optional>
#include <bytes/bytes.h>
#include <chrono>
#include <hpke/signature.h>
#include <map>
using namespace mlspp::bytes_ns;
namespace mlspp::hpke {
struct UserInfoClaimsAddress
{
std::optional<std::string> formatted;
std::optional<std::string> street_address;
std::optional<std::string> locality;
std::optional<std::string> region;
std::optional<std::string> postal_code;
std::optional<std::string> country;
};
struct UserInfoClaims
{
std::optional<std::string> sub;
std::optional<std::string> name;
std::optional<std::string> given_name;
std::optional<std::string> family_name;
std::optional<std::string> middle_name;
std::optional<std::string> nickname;
std::optional<std::string> preferred_username;
std::optional<std::string> profile;
std::optional<std::string> picture;
std::optional<std::string> website;
std::optional<std::string> email;
std::optional<bool> email_verified;
std::optional<std::string> gender;
std::optional<std::string> birthdate;
std::optional<std::string> zoneinfo;
std::optional<std::string> locale;
std::optional<std::string> phone_number;
std::optional<bool> phone_number_verified;
std::optional<UserInfoClaimsAddress> address;
std::optional<uint64_t> updated_at;
static UserInfoClaims from_json(const std::string& cred_subject);
};
struct UserInfoVC
{
private:
struct ParsedCredential;
std::shared_ptr<ParsedCredential> parsed_cred;
public:
explicit UserInfoVC(std::string jwt);
UserInfoVC() = default;
UserInfoVC(const UserInfoVC& other) = default;
~UserInfoVC() = default;
UserInfoVC& operator=(const UserInfoVC& other) = default;
UserInfoVC& operator=(UserInfoVC&& other) = default;
const Signature& signature_algorithm() const;
std::string issuer() const;
std::optional<std::string> key_id() const;
std::chrono::system_clock::time_point not_before() const;
std::chrono::system_clock::time_point not_after() const;
const std::string& raw_credential() const;
const UserInfoClaims& subject() const;
const Signature::PublicJWK& public_key() const;
bool valid_from(const Signature::PublicKey& issuer_key) const;
std::string raw;
};
bool
operator==(const UserInfoVC& lhs, const UserInfoVC& rhs);
} // namespace mlspp::hpke

View File

@@ -0,0 +1,321 @@
#include "aead_cipher.h"
#include "openssl_common.h"
#include <openssl/evp.h>
#if WITH_BORINGSSL
#include <openssl/aead.h>
#endif
namespace mlspp::hpke {
///
/// ExportOnlyCipher
///
bytes
ExportOnlyCipher::seal(const bytes& /* key */,
const bytes& /* nonce */,
const bytes& /* aad */,
const bytes& /* pt */) const
{
throw std::runtime_error("seal() on export-only context");
}
std::optional<bytes>
ExportOnlyCipher::open(const bytes& /* key */,
const bytes& /* nonce */,
const bytes& /* aad */,
const bytes& /* ct */) const
{
throw std::runtime_error("open() on export-only context");
}
ExportOnlyCipher::ExportOnlyCipher()
: AEAD(AEAD::ID::export_only, 0, 0)
{
}
///
/// AEADCipher
///
AEADCipher
make_aead(AEAD::ID cipher_in)
{
return { cipher_in };
}
template<>
const AEADCipher&
AEADCipher::get<AEAD::ID::AES_128_GCM>()
{
static const auto instance = make_aead(AEAD::ID::AES_128_GCM);
return instance;
}
template<>
const AEADCipher&
AEADCipher::get<AEAD::ID::AES_256_GCM>()
{
static const auto instance = make_aead(AEAD::ID::AES_256_GCM);
return instance;
}
template<>
const AEADCipher&
AEADCipher::get<AEAD::ID::CHACHA20_POLY1305>()
{
static const auto instance = make_aead(AEAD::ID::CHACHA20_POLY1305);
return instance;
}
static size_t
cipher_key_size(AEAD::ID cipher)
{
switch (cipher) {
case AEAD::ID::AES_128_GCM:
return 16;
case AEAD::ID::AES_256_GCM:
case AEAD::ID::CHACHA20_POLY1305:
return 32;
default:
throw std::runtime_error("Unsupported algorithm");
}
}
static size_t
cipher_nonce_size(AEAD::ID cipher)
{
switch (cipher) {
case AEAD::ID::AES_128_GCM:
case AEAD::ID::AES_256_GCM:
case AEAD::ID::CHACHA20_POLY1305:
return 12;
default:
throw std::runtime_error("Unsupported algorithm");
}
}
static size_t
cipher_tag_size(AEAD::ID cipher)
{
switch (cipher) {
case AEAD::ID::AES_128_GCM:
case AEAD::ID::AES_256_GCM:
case AEAD::ID::CHACHA20_POLY1305:
return 16;
default:
throw std::runtime_error("Unsupported algorithm");
}
}
#if WITH_BORINGSSL
static const EVP_AEAD*
boringssl_cipher(AEAD::ID cipher)
{
switch (cipher) {
case AEAD::ID::AES_128_GCM:
return EVP_aead_aes_128_gcm();
case AEAD::ID::AES_256_GCM:
return EVP_aead_aes_256_gcm();
case AEAD::ID::CHACHA20_POLY1305:
return EVP_aead_chacha20_poly1305();
default:
throw std::runtime_error("Unsupported algorithm");
}
}
#else
static const EVP_CIPHER*
openssl_cipher(AEAD::ID cipher)
{
switch (cipher) {
case AEAD::ID::AES_128_GCM:
return EVP_aes_128_gcm();
case AEAD::ID::AES_256_GCM:
return EVP_aes_256_gcm();
case AEAD::ID::CHACHA20_POLY1305:
return EVP_chacha20_poly1305();
default:
throw std::runtime_error("Unsupported algorithm");
}
}
#endif // WITH_BORINGSSL
AEADCipher::AEADCipher(AEAD::ID id_in)
: AEAD(id_in, cipher_key_size(id_in), cipher_nonce_size(id_in))
, tag_size(cipher_tag_size(id))
{
}
bytes
AEADCipher::seal(const bytes& key,
const bytes& nonce,
const bytes& aad,
const bytes& pt) const
{
#if WITH_BORINGSSL
auto ctx = make_typed_unique(
EVP_AEAD_CTX_new(boringssl_cipher(id), key.data(), key.size(), tag_size));
if (ctx == nullptr) {
throw openssl_error();
}
auto ct = bytes(pt.size() + tag_size);
auto out_len = ct.size();
if (1 != EVP_AEAD_CTX_seal(ctx.get(),
ct.data(),
&out_len,
ct.size(),
nonce.data(),
nonce.size(),
pt.data(),
pt.size(),
aad.data(),
aad.size())) {
throw openssl_error();
}
return ct;
#else
auto ctx = make_typed_unique(EVP_CIPHER_CTX_new());
if (ctx == nullptr) {
throw openssl_error();
}
const auto* cipher = openssl_cipher(id);
if (1 != EVP_EncryptInit(ctx.get(), cipher, key.data(), nonce.data())) {
throw openssl_error();
}
int outlen = 0;
if (!aad.empty()) {
if (1 != EVP_EncryptUpdate(ctx.get(),
nullptr,
&outlen,
aad.data(),
static_cast<int>(aad.size()))) {
throw openssl_error();
}
}
bytes ct(pt.size());
if (1 != EVP_EncryptUpdate(ctx.get(),
ct.data(),
&outlen,
pt.data(),
static_cast<int>(pt.size()))) {
throw openssl_error();
}
// Providing nullptr as an argument is safe here because this
// function never writes with GCM; it only computes the tag
if (1 != EVP_EncryptFinal(ctx.get(), nullptr, &outlen)) {
throw openssl_error();
}
bytes tag(tag_size);
if (1 != EVP_CIPHER_CTX_ctrl(ctx.get(),
EVP_CTRL_GCM_GET_TAG,
static_cast<int>(tag_size),
tag.data())) {
throw openssl_error();
}
ct += tag;
return ct;
#endif // WITH_BORINGSSL
}
std::optional<bytes>
AEADCipher::open(const bytes& key,
const bytes& nonce,
const bytes& aad,
const bytes& ct) const
{
if (ct.size() < tag_size) {
throw std::runtime_error("AEAD ciphertext smaller than tag size");
}
#if WITH_BORINGSSL
auto ctx = make_typed_unique(EVP_AEAD_CTX_new(
boringssl_cipher(id), key.data(), key.size(), cipher_tag_size(id)));
if (ctx == nullptr) {
throw openssl_error();
}
auto pt = bytes(ct.size() - tag_size);
auto out_len = pt.size();
if (1 != EVP_AEAD_CTX_open(ctx.get(),
pt.data(),
&out_len,
pt.size(),
nonce.data(),
nonce.size(),
ct.data(),
ct.size(),
aad.data(),
aad.size())) {
throw openssl_error();
}
return pt;
#else
auto ctx = make_typed_unique(EVP_CIPHER_CTX_new());
if (ctx == nullptr) {
throw openssl_error();
}
const auto* cipher = openssl_cipher(id);
if (1 != EVP_DecryptInit(ctx.get(), cipher, key.data(), nonce.data())) {
throw openssl_error();
}
auto inner_ct_size = ct.size() - tag_size;
auto tag = ct.slice(inner_ct_size, ct.size());
if (1 != EVP_CIPHER_CTX_ctrl(ctx.get(),
EVP_CTRL_GCM_SET_TAG,
static_cast<int>(tag_size),
tag.data())) {
throw openssl_error();
}
int out_size = 0;
if (!aad.empty()) {
if (1 != EVP_DecryptUpdate(ctx.get(),
nullptr,
&out_size,
aad.data(),
static_cast<int>(aad.size()))) {
throw openssl_error();
}
}
bytes pt(inner_ct_size);
if (1 != EVP_DecryptUpdate(ctx.get(),
pt.data(),
&out_size,
ct.data(),
static_cast<int>(inner_ct_size))) {
throw openssl_error();
}
// Providing nullptr as an argument is safe here because this
// function never writes with GCM; it only verifies the tag
if (1 != EVP_DecryptFinal(ctx.get(), nullptr, &out_size)) {
throw std::runtime_error("AEAD authentication failure");
}
return pt;
#endif // WITH_BORINGSSL
}
} // namespace mlspp::hpke

View File

@@ -0,0 +1,45 @@
#pragma once
#include <hpke/hpke.h>
namespace mlspp::hpke {
struct ExportOnlyCipher : public AEAD
{
ExportOnlyCipher();
~ExportOnlyCipher() override = default;
bytes seal(const bytes& key,
const bytes& nonce,
const bytes& aad,
const bytes& pt) const override;
std::optional<bytes> open(const bytes& key,
const bytes& nonce,
const bytes& aad,
const bytes& ct) const override;
};
struct AEADCipher : public AEAD
{
template<AEAD::ID id>
static const AEADCipher& get();
~AEADCipher() override = default;
bytes seal(const bytes& key,
const bytes& nonce,
const bytes& aad,
const bytes& pt) const override;
std::optional<bytes> open(const bytes& key,
const bytes& nonce,
const bytes& aad,
const bytes& ct) const override;
private:
const size_t tag_size;
AEADCipher(AEAD::ID id_in);
friend AEADCipher make_aead(AEAD::ID cipher_in);
};
} // namespace mlspp::hpke

105
DPP/mlspp/lib/hpke/src/base64.cpp Executable file
View File

@@ -0,0 +1,105 @@
#include <hpke/base64.h>
#include "openssl_common.h"
#include <openssl/bio.h>
#include <openssl/err.h>
#include <openssl/evp.h>
namespace mlspp::hpke {
std::string
to_base64(const bytes& data)
{
if (data.empty()) {
return "";
}
#if WITH_BORINGSSL
const auto data_size = data.size();
#else
const auto data_size = static_cast<int>(data.size());
#endif
// base64 encoding produces 4 characters for every 3 input bytes (rounded up)
const auto out_size = (data_size + 2) / 3 * 4;
auto out = bytes(out_size + 1); // NUL terminator
const auto result = EVP_EncodeBlock(out.data(), data.data(), data_size);
if (result != out_size) {
throw openssl_error();
}
out.resize(out.size() - 1); // strip NUL terminator
return to_ascii(out);
}
std::string
to_base64url(const bytes& data)
{
if (data.empty()) {
return "";
}
auto encoded = to_base64(data);
auto pad_start = encoded.find_first_of('=');
if (pad_start != std::string::npos) {
encoded = encoded.substr(0, pad_start);
}
std::replace(encoded.begin(), encoded.end(), '+', '-');
std::replace(encoded.begin(), encoded.end(), '/', '_');
return encoded;
}
bytes
from_base64(const std::string& enc)
{
if (enc.length() == 0) {
return {};
}
if (enc.length() % 4 != 0) {
throw std::runtime_error("Base64 length is not divisible by 4");
}
const auto in = from_ascii(enc);
const auto in_size = static_cast<int>(in.size());
const auto out_size = in_size / 4 * 3;
auto out = bytes(out_size);
const auto result = EVP_DecodeBlock(out.data(), in.data(), in_size);
if (result != out_size) {
throw openssl_error();
}
if (enc.substr(enc.length() - 2, enc.length()) == "==") {
out.resize(out.size() - 2);
} else if (enc.substr(enc.length() - 1, enc.length()) == "=") {
out.resize(out.size() - 1);
}
return out;
}
bytes
from_base64url(const std::string& enc)
{
if (enc.empty()) {
return {};
}
auto enc_copy = enc;
std::replace(enc_copy.begin(), enc_copy.end(), '-', '+');
std::replace(enc_copy.begin(), enc_copy.end(), '_', '/');
while (enc_copy.length() % 4 != 0) {
enc_copy += "=";
}
return from_base64(enc_copy);
}
} // namespace mlspp::hpke

View File

@@ -0,0 +1,539 @@
#include "group.h"
#include "openssl_common.h"
#include "rsa.h"
#include <hpke/certificate.h>
#include <hpke/signature.h>
#include <memory>
#include <openssl/err.h>
#include <openssl/pem.h>
#include <openssl/x509.h>
#include <openssl/x509v3.h>
#include <tls/compat.h>
namespace mlspp::hpke {
///
/// Utility functions
///
static std::optional<bytes>
asn1_octet_string_to_bytes(const ASN1_OCTET_STRING* octets)
{
if (octets == nullptr) {
return std::nullopt;
}
const auto* ptr = ASN1_STRING_get0_data(octets);
const auto len = ASN1_STRING_length(octets);
// NOLINTNEXTLINE (cppcoreguidelines-pro-bounds-pointer-arithmetic)
return std::vector<uint8_t>(ptr, ptr + len);
}
static std::string
asn1_string_to_std_string(const ASN1_STRING* asn1_string)
{
const auto* data = ASN1_STRING_get0_data(asn1_string);
const auto data_size = static_cast<size_t>(ASN1_STRING_length(asn1_string));
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
auto str = std::string(reinterpret_cast<const char*>(data));
if (str.size() != data_size) {
throw std::runtime_error("Malformed ASN.1 string");
}
return str;
}
static std::chrono::system_clock::time_point
asn1_time_to_chrono(const ASN1_TIME* asn1_time)
{
auto epoch_chrono = std::chrono::system_clock::time_point();
auto epoch_time_t = std::chrono::system_clock::to_time_t(epoch_chrono);
auto epoch_asn1 = make_typed_unique(ASN1_TIME_set(nullptr, epoch_time_t));
if (!epoch_asn1) {
throw openssl_error();
}
auto secs = int(0);
auto days = int(0);
if (ASN1_TIME_diff(&days, &secs, epoch_asn1.get(), asn1_time) != 1) {
throw openssl_error();
}
auto delta = std::chrono::seconds(secs) + std::chrono::hours(24 * days);
return std::chrono::system_clock::time_point(delta);
}
///
/// ParsedCertificate
///
const int Certificate::NameType::organization = NID_organizationName;
const int Certificate::NameType::common_name = NID_commonName;
const int Certificate::NameType::organizational_unit =
NID_organizationalUnitName;
const int Certificate::NameType::country = NID_countryName;
const int Certificate::NameType::serial_number = NID_serialNumber;
const int Certificate::NameType::state_or_province_name =
NID_stateOrProvinceName;
struct RFC822Name
{
std::string value;
};
struct DNSName
{
std::string value;
};
using GeneralName = tls::var::variant<RFC822Name, DNSName>;
struct Certificate::ParsedCertificate
{
static std::unique_ptr<ParsedCertificate> parse(const bytes& der)
{
const auto* buf = der.data();
auto cert =
make_typed_unique(d2i_X509(nullptr, &buf, static_cast<int>(der.size())));
if (cert == nullptr) {
throw openssl_error();
}
return std::make_unique<ParsedCertificate>(cert.release());
}
static bytes compute_digest(const X509* cert)
{
const auto* md = EVP_sha256();
auto digest = bytes(EVP_MD_size(md));
unsigned int out_size = 0;
if (1 != X509_digest(cert, md, digest.data(), &out_size)) {
throw openssl_error();
}
return digest;
}
// Note: This method does not implement total general name parsing.
// Duplicate entries are not supported; if they are present, the last one
// presented by OpenSSL is chosen.
static ParsedName parse_names(const X509_NAME* x509_name)
{
if (x509_name == nullptr) {
throw openssl_error();
}
ParsedName parsed_name;
for (int i = X509_NAME_entry_count(x509_name) - 1; i >= 0; i--) {
auto* entry = X509_NAME_get_entry(x509_name, i);
if (entry == nullptr) {
continue;
}
auto* oid = X509_NAME_ENTRY_get_object(entry);
auto* asn_str = X509_NAME_ENTRY_get_data(entry);
if (oid == nullptr || asn_str == nullptr) {
continue;
}
const int nid = OBJ_obj2nid(oid);
const std::string parsed_value = asn1_string_to_std_string(asn_str);
parsed_name[nid] = parsed_value;
}
return parsed_name;
}
// Parse Subject Key Identifier Extension
static std::optional<bytes> parse_skid(X509* cert)
{
return asn1_octet_string_to_bytes(X509_get0_subject_key_id(cert));
}
// Parse Authority Key Identifier
static std::optional<bytes> parse_akid(X509* cert)
{
return asn1_octet_string_to_bytes(X509_get0_authority_key_id(cert));
}
static std::vector<GeneralName> parse_san(X509* cert)
{
std::vector<GeneralName> names;
#ifdef WITH_BORINGSSL
using san_names_nb_t = size_t;
#else
using san_names_nb_t = int;
#endif
san_names_nb_t san_names_nb = 0;
auto* ext_ptr =
X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr);
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
auto* san_ptr = reinterpret_cast<STACK_OF(GENERAL_NAME)*>(ext_ptr);
const auto san_names = make_typed_unique(san_ptr);
san_names_nb = sk_GENERAL_NAME_num(san_names.get());
// Check each name within the extension
for (san_names_nb_t i = 0; i < san_names_nb; i++) {
auto* current_name = sk_GENERAL_NAME_value(san_names.get(), i);
if (current_name->type == GEN_DNS) {
const auto dns_name =
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-union-access)
asn1_string_to_std_string(current_name->d.dNSName);
names.emplace_back(DNSName{ dns_name });
} else if (current_name->type == GEN_EMAIL) {
const auto email =
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-union-access
asn1_string_to_std_string(current_name->d.rfc822Name);
names.emplace_back(RFC822Name{ email });
}
}
return names;
}
explicit ParsedCertificate(X509* x509_in)
: x509(x509_in, typed_delete<X509>)
, pub_key_id(public_key_algorithm(x509.get()))
, sig_algo(signature_algorithm(x509.get()))
, issuer_hash(X509_issuer_name_hash(x509.get()))
, subject_hash(X509_subject_name_hash(x509.get()))
, issuer(parse_names(X509_get_issuer_name(x509.get())))
, subject(parse_names(X509_get_subject_name(x509.get())))
, subject_key_id(parse_skid(x509.get()))
, authority_key_id(parse_akid(x509.get()))
, sub_alt_names(parse_san(x509.get()))
, is_ca(X509_check_ca(x509.get()) != 0)
, hash(compute_digest(x509.get()))
, not_before(asn1_time_to_chrono(X509_get0_notBefore(x509.get())))
, not_after(asn1_time_to_chrono(X509_get0_notAfter(x509.get())))
{
}
ParsedCertificate(const ParsedCertificate& other)
: x509(nullptr, typed_delete<X509>)
, pub_key_id(public_key_algorithm(other.x509.get()))
, sig_algo(signature_algorithm(other.x509.get()))
, issuer_hash(other.issuer_hash)
, subject_hash(other.subject_hash)
, issuer(other.issuer)
, subject(other.subject)
, subject_key_id(other.subject_key_id)
, authority_key_id(other.authority_key_id)
, sub_alt_names(other.sub_alt_names)
, is_ca(other.is_ca)
, hash(other.hash)
, not_before(other.not_before)
, not_after(other.not_after)
{
if (1 != X509_up_ref(other.x509.get())) {
throw openssl_error();
}
x509.reset(other.x509.get());
}
static Signature::ID public_key_algorithm(X509* x509)
{
#if WITH_BORINGSSL
const auto pub = make_typed_unique(X509_get_pubkey(x509));
const auto* pub_ptr = pub.get();
#else
const auto* pub_ptr = X509_get0_pubkey(x509);
#endif
switch (EVP_PKEY_base_id(pub_ptr)) {
case EVP_PKEY_ED25519:
return Signature::ID::Ed25519;
#if !defined(WITH_BORINGSSL)
case EVP_PKEY_ED448:
return Signature::ID::Ed448;
#endif
case EVP_PKEY_EC: {
auto key_size = EVP_PKEY_bits(pub_ptr);
switch (key_size) {
case 256:
return Signature::ID::P256_SHA256;
case 384:
return Signature::ID::P384_SHA384;
case 521:
return Signature::ID::P521_SHA512;
default:
throw std::runtime_error("Unknown curve");
}
}
case EVP_PKEY_RSA:
// RSA public keys are not specific to an algorithm
return Signature::ID::RSA_SHA256;
default:
break;
}
throw std::runtime_error("Unsupported public key algorithm");
}
static Signature::ID signature_algorithm(X509* cert)
{
auto nid = X509_get_signature_nid(cert);
switch (nid) {
case EVP_PKEY_ED25519:
return Signature::ID::Ed25519;
#if !defined(WITH_BORINGSSL)
case EVP_PKEY_ED448:
return Signature::ID::Ed448;
#endif
case NID_ecdsa_with_SHA256:
return Signature::ID::P256_SHA256;
case NID_ecdsa_with_SHA384:
return Signature::ID::P384_SHA384;
case NID_ecdsa_with_SHA512:
return Signature::ID::P521_SHA512;
case NID_sha1WithRSAEncryption:
// We fall through to SHA256 for SHA1 because we do not implement SHA-1.
case NID_sha256WithRSAEncryption:
return Signature::ID::RSA_SHA256;
case NID_sha384WithRSAEncryption:
return Signature::ID::RSA_SHA384;
case NID_sha512WithRSAEncryption:
return Signature::ID::RSA_SHA512;
default:
break;
}
throw std::runtime_error("Unsupported signature algorithm");
}
typed_unique_ptr<EVP_PKEY> public_key() const
{
return make_typed_unique<EVP_PKEY>(X509_get_pubkey(x509.get()));
}
Certificate::ExpirationStatus expiration_status() const
{
auto now = std::chrono::system_clock::now();
if (now < not_before) {
return Certificate::ExpirationStatus::inactive;
}
if (now > not_after) {
return Certificate::ExpirationStatus::expired;
}
return Certificate::ExpirationStatus::active;
}
bytes raw() const
{
auto out = bytes(i2d_X509(x509.get(), nullptr));
auto* ptr = out.data();
i2d_X509(x509.get(), &ptr);
return out;
}
typed_unique_ptr<X509> x509;
const Signature::ID pub_key_id;
const Signature::ID sig_algo;
const uint64_t issuer_hash;
const uint64_t subject_hash;
const ParsedName issuer;
const ParsedName subject;
const std::optional<bytes> subject_key_id;
const std::optional<bytes> authority_key_id;
const std::vector<GeneralName> sub_alt_names;
const bool is_ca;
const bytes hash;
const std::chrono::system_clock::time_point not_before;
const std::chrono::system_clock::time_point not_after;
};
///
/// Certificate
///
static std::unique_ptr<Signature::PublicKey>
signature_key(EVP_PKEY* pkey)
{
switch (EVP_PKEY_base_id(pkey)) {
case EVP_PKEY_RSA:
return std::make_unique<RSASignature::PublicKey>(pkey);
case EVP_PKEY_ED448:
case EVP_PKEY_ED25519:
case EVP_PKEY_EC:
return std::make_unique<EVPGroup::PublicKey>(pkey);
default:
throw std::runtime_error("Unsupported algorithm");
}
}
Certificate::Certificate(std::unique_ptr<ParsedCertificate>&& parsed_cert_in)
: parsed_cert(std::move(parsed_cert_in))
, public_key(signature_key(parsed_cert->public_key().release()))
, raw(parsed_cert->raw())
{
}
Certificate::Certificate(const bytes& der)
: parsed_cert(ParsedCertificate::parse(der))
, public_key(signature_key(parsed_cert->public_key().release()))
, raw(der)
{
}
Certificate::Certificate(const Certificate& other)
: parsed_cert(std::make_unique<ParsedCertificate>(*other.parsed_cert))
, public_key(signature_key(parsed_cert->public_key().release()))
, raw(other.raw)
{
}
Certificate::~Certificate() = default;
std::vector<Certificate>
Certificate::parse_pem(const bytes& pem)
{
auto size_int = static_cast<int>(pem.size());
auto bio = make_typed_unique<BIO>(BIO_new_mem_buf(pem.data(), size_int));
if (!bio) {
throw openssl_error();
}
auto certs = std::vector<Certificate>();
while (true) {
auto x509 = make_typed_unique<X509>(
PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
if (!x509) {
// NOLINTNEXTLINE(hicpp-signed-bitwise)
auto err = ERR_GET_REASON(ERR_peek_last_error());
if (err == PEM_R_NO_START_LINE) {
// No more objects to read
break;
}
throw openssl_error();
}
auto parsed = std::make_unique<ParsedCertificate>(x509.release());
certs.emplace_back(std::move(parsed));
}
return certs;
}
bool
Certificate::valid_from(const Certificate& parent) const
{
auto pub = parent.parsed_cert->public_key();
return (1 == X509_verify(parsed_cert->x509.get(), pub.get()));
}
uint64_t
Certificate::issuer_hash() const
{
return parsed_cert->issuer_hash;
}
uint64_t
Certificate::subject_hash() const
{
return parsed_cert->subject_hash;
}
Certificate::ParsedName
Certificate::subject() const
{
return parsed_cert->subject;
}
Certificate::ParsedName
Certificate::issuer() const
{
return parsed_cert->issuer;
}
bool
Certificate::is_ca() const
{
return parsed_cert->is_ca;
}
Certificate::ExpirationStatus
Certificate::expiration_status() const
{
return parsed_cert->expiration_status();
}
std::optional<bytes>
Certificate::subject_key_id() const
{
return parsed_cert->subject_key_id;
}
std::optional<bytes>
Certificate::authority_key_id() const
{
return parsed_cert->authority_key_id;
}
std::vector<std::string>
Certificate::email_addresses() const
{
std::vector<std::string> emails;
for (const auto& name : parsed_cert->sub_alt_names) {
if (tls::var::holds_alternative<RFC822Name>(name)) {
emails.emplace_back(tls::var::get<RFC822Name>(name).value);
}
}
return emails;
}
std::vector<std::string>
Certificate::dns_names() const
{
std::vector<std::string> domains;
for (const auto& name : parsed_cert->sub_alt_names) {
if (tls::var::holds_alternative<DNSName>(name)) {
domains.emplace_back(tls::var::get<DNSName>(name).value);
}
}
return domains;
}
bytes
Certificate::hash() const
{
return parsed_cert->hash;
}
std::chrono::system_clock::time_point
Certificate::not_before() const
{
return parsed_cert->not_before;
}
std::chrono::system_clock::time_point
Certificate::not_after() const
{
return parsed_cert->not_after;
}
Signature::ID
Certificate::public_key_algorithm() const
{
return parsed_cert->pub_key_id;
}
Signature::ID
Certificate::signature_algorithm() const
{
return parsed_cert->sig_algo;
}
bool
operator==(const Certificate& lhs, const Certificate& rhs)
{
return lhs.raw == rhs.raw;
}
} // namespace mlspp::hpke

View File

@@ -0,0 +1,20 @@
#include "common.h"
namespace mlspp::hpke {
bytes
i2osp(uint64_t val, size_t size)
{
auto out = bytes(size, 0);
auto max = size;
if (size > 8) {
max = 8;
}
for (size_t i = 0; i < max; i++) {
out.at(size - i - 1) = static_cast<uint8_t>(val >> (8 * i));
}
return out;
}
} // namespace mlspp::hpke

10
DPP/mlspp/lib/hpke/src/common.h Executable file
View File

@@ -0,0 +1,10 @@
#pragma once
#include <hpke/hpke.h>
namespace mlspp::hpke {
bytes
i2osp(uint64_t val, size_t size);
} // namespace mlspp::hpke

216
DPP/mlspp/lib/hpke/src/dhkem.cpp Executable file
View File

@@ -0,0 +1,216 @@
#include "dhkem.h"
#include "common.h"
namespace mlspp::hpke {
DHKEM::PrivateKey::PrivateKey(Group::PrivateKey* group_priv_in)
: group_priv(group_priv_in)
{
}
std::unique_ptr<KEM::PublicKey>
DHKEM::PrivateKey::public_key() const
{
return group_priv->public_key();
}
DHKEM
make_dhkem(KEM::ID kem_id_in, const Group& group_in, const KDF& kdf_in)
{
return { kem_id_in, group_in, kdf_in };
}
template<>
const DHKEM&
DHKEM::get<KEM::ID::DHKEM_P256_SHA256>()
{
static const auto instance = make_dhkem(KEM::ID::DHKEM_P256_SHA256,
Group::get<Group::ID::P256>(),
KDF::get<KDF::ID::HKDF_SHA256>());
return instance;
}
template<>
const DHKEM&
DHKEM::get<KEM::ID::DHKEM_P384_SHA384>()
{
static const auto instance = make_dhkem(KEM::ID::DHKEM_P384_SHA384,
Group::get<Group::ID::P384>(),
KDF::get<KDF::ID::HKDF_SHA384>());
return instance;
}
template<>
const DHKEM&
DHKEM::get<KEM::ID::DHKEM_P521_SHA512>()
{
static const auto instance = make_dhkem(KEM::ID::DHKEM_P521_SHA512,
Group::get<Group::ID::P521>(),
KDF::get<KDF::ID::HKDF_SHA512>());
return instance;
}
template<>
const DHKEM&
DHKEM::get<KEM::ID::DHKEM_X25519_SHA256>()
{
static const auto instance = make_dhkem(KEM::ID::DHKEM_X25519_SHA256,
Group::get<Group::ID::X25519>(),
KDF::get<KDF::ID::HKDF_SHA256>());
return instance;
}
#if !defined(WITH_BORINGSSL)
template<>
const DHKEM&
DHKEM::get<KEM::ID::DHKEM_X448_SHA512>()
{
static const auto instance = make_dhkem(KEM::ID::DHKEM_X448_SHA512,
Group::get<Group::ID::X448>(),
KDF::get<KDF::ID::HKDF_SHA512>());
return instance;
}
#endif
DHKEM::DHKEM(KEM::ID kem_id_in, const Group& group_in, const KDF& kdf_in)
: KEM(kem_id_in,
kdf_in.hash_size,
group_in.pk_size,
group_in.pk_size,
group_in.sk_size)
, group(group_in)
, kdf(kdf_in)
{
static const auto label_kem = from_ascii("KEM");
suite_id = label_kem + i2osp(uint16_t(kem_id_in), 2);
}
std::unique_ptr<KEM::PrivateKey>
DHKEM::generate_key_pair() const
{
return std::make_unique<DHKEM::PrivateKey>(
group.generate_key_pair().release());
}
std::unique_ptr<KEM::PrivateKey>
DHKEM::derive_key_pair(const bytes& ikm) const
{
return std::make_unique<DHKEM::PrivateKey>(
group.derive_key_pair(suite_id, ikm).release());
}
bytes
DHKEM::serialize(const KEM::PublicKey& pk) const
{
const auto& gpk = dynamic_cast<const Group::PublicKey&>(pk);
return group.serialize(gpk);
}
std::unique_ptr<KEM::PublicKey>
DHKEM::deserialize(const bytes& enc) const
{
return group.deserialize(enc);
}
bytes
DHKEM::serialize_private(const KEM::PrivateKey& sk) const
{
const auto& gsk = dynamic_cast<const PrivateKey&>(sk);
return group.serialize_private(*gsk.group_priv);
}
std::unique_ptr<KEM::PrivateKey>
DHKEM::deserialize_private(const bytes& skm) const
{
return std::make_unique<PrivateKey>(group.deserialize_private(skm).release());
}
std::pair<bytes, bytes>
DHKEM::encap(const KEM::PublicKey& pkR) const
{
const auto& gpkR = dynamic_cast<const Group::PublicKey&>(pkR);
auto skE = group.generate_key_pair();
auto pkE = skE->public_key();
auto zz = group.dh(*skE, gpkR);
auto enc = group.serialize(*pkE);
auto pkRm = group.serialize(gpkR);
auto kem_context = enc + pkRm;
auto shared_secret = extract_and_expand(zz, kem_context);
return std::make_pair(shared_secret, enc);
}
bytes
DHKEM::decap(const bytes& enc, const KEM::PrivateKey& skR) const
{
const auto& gskR = dynamic_cast<const PrivateKey&>(skR);
auto pkR = gskR.group_priv->public_key();
auto pkE = group.deserialize(enc);
auto zz = group.dh(*gskR.group_priv, *pkE);
auto pkRm = group.serialize(*pkR);
auto kem_context = enc + pkRm;
return extract_and_expand(zz, kem_context);
}
std::pair<bytes, bytes>
DHKEM::auth_encap(const KEM::PublicKey& pkR, const KEM::PrivateKey& skS) const
{
const auto& gpkR = dynamic_cast<const Group::PublicKey&>(pkR);
const auto& gskS = dynamic_cast<const PrivateKey&>(skS);
auto skE = group.generate_key_pair();
auto pkE = skE->public_key();
auto pkS = gskS.group_priv->public_key();
auto zzER = group.dh(*skE, gpkR);
auto zzSR = group.dh(*gskS.group_priv, gpkR);
auto zz = zzER + zzSR;
auto enc = group.serialize(*pkE);
auto pkRm = group.serialize(gpkR);
auto pkSm = group.serialize(*pkS);
auto kem_context = enc + pkRm + pkSm;
auto shared_secret = extract_and_expand(zz, kem_context);
return std::make_pair(shared_secret, enc);
}
bytes
DHKEM::auth_decap(const bytes& enc,
const KEM::PublicKey& pkS,
const KEM::PrivateKey& skR) const
{
const auto& gpkS = dynamic_cast<const Group::PublicKey&>(pkS);
const auto& gskR = dynamic_cast<const PrivateKey&>(skR);
auto pkE = group.deserialize(enc);
auto pkR = gskR.group_priv->public_key();
auto zzER = group.dh(*gskR.group_priv, *pkE);
auto zzSR = group.dh(*gskR.group_priv, gpkS);
auto zz = zzER + zzSR;
auto pkRm = group.serialize(*pkR);
auto pkSm = group.serialize(gpkS);
auto kem_context = enc + pkRm + pkSm;
return extract_and_expand(zz, kem_context);
}
bytes
DHKEM::extract_and_expand(const bytes& dh, const bytes& kem_context) const
{
static const auto label_eae_prk = from_ascii("eae_prk");
static const auto label_shared_secret = from_ascii("shared_secret");
auto eae_prk = kdf.labeled_extract(suite_id, {}, label_eae_prk, dh);
return kdf.labeled_expand(
suite_id, eae_prk, label_shared_secret, kem_context, secret_size);
}
} // namespace mlspp::hpke

57
DPP/mlspp/lib/hpke/src/dhkem.h Executable file
View File

@@ -0,0 +1,57 @@
#pragma once
#include <hpke/hpke.h>
#include "group.h"
namespace mlspp::hpke {
struct DHKEM : public KEM
{
struct PrivateKey : public KEM::PrivateKey
{
PrivateKey(Group::PrivateKey* group_priv_in);
std::unique_ptr<KEM::PublicKey> public_key() const override;
std::unique_ptr<Group::PrivateKey> group_priv;
};
template<KEM::ID>
static const DHKEM& get();
~DHKEM() override = default;
std::unique_ptr<KEM::PrivateKey> generate_key_pair() const override;
std::unique_ptr<KEM::PrivateKey> derive_key_pair(
const bytes& ikm) const override;
bytes serialize(const KEM::PublicKey& pk) const override;
std::unique_ptr<KEM::PublicKey> deserialize(const bytes& enc) const override;
bytes serialize_private(const KEM::PrivateKey& sk) const override;
std::unique_ptr<KEM::PrivateKey> deserialize_private(
const bytes& skm) const override;
std::pair<bytes, bytes> encap(const KEM::PublicKey& pk) const override;
bytes decap(const bytes& enc, const KEM::PrivateKey& sk) const override;
std::pair<bytes, bytes> auth_encap(const KEM::PublicKey& pkR,
const KEM::PrivateKey& skS) const override;
bytes auth_decap(const bytes& enc,
const KEM::PublicKey& pkS,
const KEM::PrivateKey& skR) const override;
private:
const Group& group;
const KDF& kdf;
bytes suite_id;
bytes extract_and_expand(const bytes& dh, const bytes& kem_context) const;
DHKEM(KEM::ID kem_id_in, const Group& group_in, const KDF& kdf_in);
friend DHKEM make_dhkem(KEM::ID kem_id_in,
const Group& group_in,
const KDF& kdf_in);
};
} // namespace mlspp::hpke

187
DPP/mlspp/lib/hpke/src/digest.cpp Executable file
View File

@@ -0,0 +1,187 @@
#include <hpke/digest.h>
#include <openssl/evp.h>
#include <openssl/hmac.h>
#if defined(WITH_OPENSSL3)
#include <openssl/core_names.h>
#endif
#include "openssl_common.h"
namespace mlspp::hpke {
static const EVP_MD*
openssl_digest_type(Digest::ID digest)
{
switch (digest) {
case Digest::ID::SHA256:
return EVP_sha256();
case Digest::ID::SHA384:
return EVP_sha384();
case Digest::ID::SHA512:
return EVP_sha512();
default:
throw std::runtime_error("Unsupported ciphersuite");
}
}
#if defined(WITH_OPENSSL3)
static std::string
openssl_digest_name(Digest::ID digest)
{
switch (digest) {
case Digest::ID::SHA256:
return OSSL_DIGEST_NAME_SHA2_256;
case Digest::ID::SHA384:
return OSSL_DIGEST_NAME_SHA2_384;
case Digest::ID::SHA512:
return OSSL_DIGEST_NAME_SHA2_512;
default:
throw std::runtime_error("Unsupported digest algorithm");
}
}
#endif
template<>
const Digest&
Digest::get<Digest::ID::SHA256>()
{
static const Digest instance(Digest::ID::SHA256);
return instance;
}
template<>
const Digest&
Digest::get<Digest::ID::SHA384>()
{
static const Digest instance(Digest::ID::SHA384);
return instance;
}
template<>
const Digest&
Digest::get<Digest::ID::SHA512>()
{
static const Digest instance(Digest::ID::SHA512);
return instance;
}
Digest::Digest(Digest::ID id_in)
: id(id_in)
, hash_size(EVP_MD_size(openssl_digest_type(id_in)))
{
}
bytes
Digest::hash(const bytes& data) const
{
auto md = bytes(hash_size);
unsigned int size = 0;
const auto* type = openssl_digest_type(id);
if (1 !=
EVP_Digest(data.data(), data.size(), md.data(), &size, type, nullptr)) {
throw openssl_error();
}
return md;
}
bytes
Digest::hmac(const bytes& key, const bytes& data) const
{
auto md = bytes(hash_size);
unsigned int size = 0;
const auto* type = openssl_digest_type(id);
if (nullptr == HMAC(type,
key.data(),
static_cast<int>(key.size()),
data.data(),
static_cast<int>(data.size()),
md.data(),
&size)) {
throw openssl_error();
}
return md;
}
bytes
Digest::hmac_for_hkdf_extract(const bytes& key, const bytes& data) const
{
#if defined(WITH_OPENSSL3)
auto digest_name = openssl_digest_name(id);
std::array<OSSL_PARAM, 2> params = {
OSSL_PARAM_construct_utf8_string(
OSSL_ALG_PARAM_DIGEST, digest_name.data(), 0),
OSSL_PARAM_construct_end()
};
const auto mac =
make_typed_unique(EVP_MAC_fetch(nullptr, OSSL_MAC_NAME_HMAC, nullptr));
const auto ctx = make_typed_unique(EVP_MAC_CTX_new(mac.get()));
#else
const auto* type = openssl_digest_type(id);
auto ctx = make_typed_unique(HMAC_CTX_new());
#endif
if (ctx == nullptr) {
throw openssl_error();
}
// Some FIPS-enabled libraries are overly conservative in their interpretation
// of NIST SP 800-131A, which requires HMAC keys to be at least 112 bits long.
// That document does not impose that requirement on HKDF, so we disable FIPS
// enforcement for purposes of HKDF.
//
// https://doi.org/10.6028/NIST.SP.800-131Ar2
auto key_size = static_cast<int>(key.size());
// OpenSSL 3 does not support the flag EVP_MD_CTX_FLAG_NON_FIPS_ALLOW anymore.
// However, OpenSSL 3 in FIPS mode doesn't seem to check the HMAC key size
// constraint.
#if !defined(WITH_OPENSSL3) && !defined(WITH_BORINGSSL)
static const auto fips_min_hmac_key_len = 14;
if (FIPS_mode() != 0 && key_size < fips_min_hmac_key_len) {
HMAC_CTX_set_flags(ctx.get(), EVP_MD_CTX_FLAG_NON_FIPS_ALLOW);
}
#endif
// Guard against sending nullptr to HMAC_Init_ex
const auto* key_data = key.data();
const auto non_null_zero_length_key = uint8_t(0);
if (key_data == nullptr) {
key_data = &non_null_zero_length_key;
}
auto md = bytes(hash_size);
#if defined(WITH_OPENSSL3)
if (1 != EVP_MAC_init(ctx.get(), key_data, key_size, params.data())) {
throw openssl_error();
}
if (1 != EVP_MAC_update(ctx.get(), data.data(), data.size())) {
throw openssl_error();
}
size_t size = 0;
if (1 != EVP_MAC_final(ctx.get(), md.data(), &size, hash_size)) {
throw openssl_error();
}
#else
if (1 != HMAC_Init_ex(ctx.get(), key_data, key_size, type, nullptr)) {
throw openssl_error();
}
if (1 != HMAC_Update(ctx.get(), data.data(), data.size())) {
throw openssl_error();
}
unsigned int size = 0;
if (1 != HMAC_Final(ctx.get(), md.data(), &size)) {
throw openssl_error();
}
#endif
return md;
}
} // namespace mlspp::hpke

1077
DPP/mlspp/lib/hpke/src/group.cpp Executable file

File diff suppressed because it is too large Load Diff

116
DPP/mlspp/lib/hpke/src/group.h Executable file
View File

@@ -0,0 +1,116 @@
#pragma once
#include <hpke/hpke.h>
#include <hpke/signature.h>
#include "openssl_common.h"
#include <openssl/evp.h>
namespace mlspp::hpke {
struct Group
{
enum struct ID : uint8_t
{
P256,
P384,
P521,
X25519,
X448,
Ed25519,
Ed448,
};
struct PublicKey
: public KEM::PublicKey
, public Signature::PublicKey
{
virtual ~PublicKey() = default;
};
struct PrivateKey
{
virtual ~PrivateKey() = default;
virtual std::unique_ptr<PublicKey> public_key() const = 0;
};
template<Group::ID id>
static const Group& get();
virtual ~Group() = default;
const ID id;
const size_t dh_size;
const size_t pk_size;
const size_t sk_size;
const std::string jwk_key_type;
const std::string jwk_curve_name;
virtual std::unique_ptr<PrivateKey> generate_key_pair() const = 0;
virtual std::unique_ptr<PrivateKey> derive_key_pair(
const bytes& suite_id,
const bytes& ikm) const = 0;
virtual bytes serialize(const PublicKey& pk) const = 0;
virtual std::unique_ptr<PublicKey> deserialize(const bytes& enc) const = 0;
virtual bytes serialize_private(const PrivateKey& sk) const = 0;
virtual std::unique_ptr<PrivateKey> deserialize_private(
const bytes& skm) const = 0;
virtual bytes dh(const PrivateKey& sk, const PublicKey& pk) const = 0;
virtual bytes sign(const bytes& data, const PrivateKey& sk) const = 0;
virtual bool verify(const bytes& data,
const bytes& sig,
const PublicKey& pk) const = 0;
virtual std::tuple<bytes, bytes> coordinates(const PublicKey& pk) const = 0;
virtual std::unique_ptr<PublicKey> public_key_from_coordinates(
const bytes& x,
const bytes& y) const = 0;
protected:
const KDF& kdf;
friend struct DHKEM;
Group(ID group_id_in, const KDF& kdf_in);
};
struct EVPGroup : public Group
{
EVPGroup(Group::ID group_id, const KDF& kdf);
struct PublicKey : public Group::PublicKey
{
explicit PublicKey(EVP_PKEY* pkey_in);
~PublicKey() override = default;
// NOLINTNEXTLINE(misc-non-private-member-variables-in-classes)
typed_unique_ptr<EVP_PKEY> pkey;
};
struct PrivateKey : public Group::PrivateKey
{
explicit PrivateKey(EVP_PKEY* pkey_in);
~PrivateKey() override = default;
std::unique_ptr<Group::PublicKey> public_key() const override;
// NOLINTNEXTLINE(misc-non-private-member-variables-in-classes)
typed_unique_ptr<EVP_PKEY> pkey;
};
std::unique_ptr<Group::PrivateKey> generate_key_pair() const override;
bytes dh(const Group::PrivateKey& sk,
const Group::PublicKey& pk) const override;
bytes sign(const bytes& data, const Group::PrivateKey& sk) const override;
bool verify(const bytes& data,
const bytes& sig,
const Group::PublicKey& pk) const override;
};
} // namespace mlspp::hpke

79
DPP/mlspp/lib/hpke/src/hkdf.cpp Executable file
View File

@@ -0,0 +1,79 @@
#include "hkdf.h"
#include "openssl_common.h"
#include <openssl/err.h>
#include <openssl/evp.h>
#include <stdexcept>
namespace mlspp::hpke {
template<>
const HKDF&
HKDF::get<Digest::ID::SHA256>()
{
static const HKDF instance(Digest::get<Digest::ID::SHA256>());
return instance;
}
template<>
const HKDF&
HKDF::get<Digest::ID::SHA384>()
{
static const HKDF instance(Digest::get<Digest::ID::SHA384>());
return instance;
}
template<>
const HKDF&
HKDF::get<Digest::ID::SHA512>()
{
static const HKDF instance(Digest::get<Digest::ID::SHA512>());
return instance;
}
static KDF::ID
digest_to_kdf(Digest::ID digest_id)
{
switch (digest_id) {
case Digest::ID::SHA256:
return KDF::ID::HKDF_SHA256;
case Digest::ID::SHA384:
return KDF::ID::HKDF_SHA384;
case Digest::ID::SHA512:
return KDF::ID::HKDF_SHA512;
}
throw std::runtime_error("Unsupported algorithm");
}
HKDF::HKDF(const Digest& digest_in)
: KDF(digest_to_kdf(digest_in.id), digest_in.hash_size)
, digest(digest_in)
{
}
bytes
HKDF::extract(const bytes& salt, const bytes& ikm) const
{
return digest.hmac_for_hkdf_extract(salt, ikm);
}
bytes
HKDF::expand(const bytes& prk, const bytes& info, size_t size) const
{
auto okm = bytes{};
auto i = uint8_t(0x00);
auto Ti = bytes{};
while (okm.size() < size) {
i += 1;
auto block = Ti + info + bytes{ i };
Ti = digest.hmac(prk, block);
okm += Ti;
}
okm.resize(size);
return okm;
}
} // namespace mlspp::hpke

24
DPP/mlspp/lib/hpke/src/hkdf.h Executable file
View File

@@ -0,0 +1,24 @@
#pragma once
#include <hpke/digest.h>
#include <hpke/hpke.h>
namespace mlspp::hpke {
struct HKDF : public KDF
{
template<Digest::ID digest_id>
static const HKDF& get();
~HKDF() override = default;
bytes extract(const bytes& salt, const bytes& ikm) const override;
bytes expand(const bytes& prk, const bytes& info, size_t size) const override;
private:
const Digest& digest;
explicit HKDF(const Digest& digest_in);
};
} // namespace mlspp::hpke

540
DPP/mlspp/lib/hpke/src/hpke.cpp Executable file
View File

@@ -0,0 +1,540 @@
#include <hpke/digest.h>
#include <hpke/hpke.h>
#include "aead_cipher.h"
#include "common.h"
#include "dhkem.h"
#include "hkdf.h"
#include <limits>
#include <stdexcept>
#include <string>
namespace mlspp::hpke {
///
/// Helper functions and constants
///
static const bytes&
label_exp()
{
static const bytes val = from_ascii("exp");
return val;
}
static const bytes&
label_hpke()
{
static const bytes val = from_ascii("HPKE");
return val;
}
static const bytes&
label_hpke_version()
{
static const bytes val = from_ascii("HPKE-v1");
return val;
}
static const bytes&
label_info_hash()
{
static const bytes val = from_ascii("info_hash");
return val;
}
static const bytes&
label_key()
{
static const bytes val = from_ascii("key");
return val;
}
static const bytes&
label_base_nonce()
{
static const bytes val = from_ascii("base_nonce");
return val;
}
static const bytes&
label_psk_id_hash()
{
static const bytes val = from_ascii("psk_id_hash");
return val;
}
static const bytes&
label_sec()
{
static const bytes val = from_ascii("sec");
return val;
}
static const bytes&
label_secret()
{
static const bytes val = from_ascii("secret");
return val;
}
///
/// Factory methods for primitives
///
KEM::KEM(ID id_in,
size_t secret_size_in,
size_t enc_size_in,
size_t pk_size_in,
size_t sk_size_in)
: id(id_in)
, secret_size(secret_size_in)
, enc_size(enc_size_in)
, pk_size(pk_size_in)
, sk_size(sk_size_in)
{
}
template<>
const KEM&
KEM::get<KEM::ID::DHKEM_P256_SHA256>()
{
return DHKEM::get<KEM::ID::DHKEM_P256_SHA256>();
}
template<>
const KEM&
KEM::get<KEM::ID::DHKEM_P384_SHA384>()
{
return DHKEM::get<KEM::ID::DHKEM_P384_SHA384>();
}
template<>
const KEM&
KEM::get<KEM::ID::DHKEM_P521_SHA512>()
{
return DHKEM::get<KEM::ID::DHKEM_P521_SHA512>();
}
template<>
const KEM&
KEM::get<KEM::ID::DHKEM_X25519_SHA256>()
{
return DHKEM::get<KEM::ID::DHKEM_X25519_SHA256>();
}
#if !defined(WITH_BORINGSSL)
template<>
const KEM&
KEM::get<KEM::ID::DHKEM_X448_SHA512>()
{
return DHKEM::get<KEM::ID::DHKEM_X448_SHA512>();
}
#endif
bytes
KEM::serialize_private(const KEM::PrivateKey& /* unused */) const
{
throw std::runtime_error("Not implemented");
}
std::unique_ptr<KEM::PrivateKey>
KEM::deserialize_private(const bytes& /* unused */) const
{
throw std::runtime_error("Not implemented");
}
std::pair<bytes, bytes>
KEM::auth_encap(const PublicKey& /* unused */,
const PrivateKey& /* unused */) const
{
throw std::runtime_error("Not implemented");
}
bytes
KEM::auth_decap(const bytes& /* unused */,
const PublicKey& /* unused */,
const PrivateKey& /* unused */) const
{
throw std::runtime_error("Not implemented");
}
template<>
const KDF&
KDF::get<KDF::ID::HKDF_SHA256>()
{
return HKDF::get<Digest::ID::SHA256>();
}
template<>
const KDF&
KDF::get<KDF::ID::HKDF_SHA384>()
{
return HKDF::get<Digest::ID::SHA384>();
}
template<>
const KDF&
KDF::get<KDF::ID::HKDF_SHA512>()
{
return HKDF::get<Digest::ID::SHA512>();
}
KDF::KDF(ID id_in, size_t hash_size_in)
: id(id_in)
, hash_size(hash_size_in)
{
}
bytes
KDF::labeled_extract(const bytes& suite_id,
const bytes& salt,
const bytes& label,
const bytes& ikm) const
{
auto labeled_ikm = label_hpke_version() + suite_id + label + ikm;
return extract(salt, labeled_ikm);
}
bytes
KDF::labeled_expand(const bytes& suite_id,
const bytes& prk,
const bytes& label,
const bytes& info,
size_t size) const
{
auto labeled_info =
i2osp(size, 2) + label_hpke_version() + suite_id + label + info;
return expand(prk, labeled_info, size);
}
template<>
const AEAD&
AEAD::get<AEAD::ID::AES_128_GCM>()
{
return AEADCipher::get<AEAD::ID::AES_128_GCM>();
}
template<>
const AEAD&
AEAD::get<AEAD::ID::AES_256_GCM>()
{
return AEADCipher::get<AEAD::ID::AES_256_GCM>();
}
template<>
const AEAD&
AEAD::get<AEAD::ID::CHACHA20_POLY1305>()
{
return AEADCipher::get<AEAD::ID::CHACHA20_POLY1305>();
}
template<>
const AEAD&
AEAD::get<AEAD::ID::export_only>()
{
static const auto export_only = ExportOnlyCipher{};
return export_only;
}
AEAD::AEAD(ID id_in, size_t key_size_in, size_t nonce_size_in)
: id(id_in)
, key_size(key_size_in)
, nonce_size(nonce_size_in)
{
}
///
/// Encryption Contexts
///
bytes
Context::do_export(const bytes& exporter_context, size_t size) const
{
return kdf.labeled_expand(
suite, exporter_secret, label_sec(), exporter_context, size);
}
bytes
Context::current_nonce() const
{
auto curr = i2osp(seq, aead.nonce_size);
return curr ^ nonce;
}
void
Context::increment_seq()
{
if (seq == std::numeric_limits<uint64_t>::max()) {
throw std::runtime_error("Sequence number overflow");
}
seq += 1;
}
Context::Context(bytes suite_in,
bytes key_in,
bytes nonce_in,
bytes exporter_secret_in,
const KDF& kdf_in,
const AEAD& aead_in)
: suite(std::move(suite_in))
, key(std::move(key_in))
, nonce(std::move(nonce_in))
, exporter_secret(std::move(exporter_secret_in))
, kdf(kdf_in)
, aead(aead_in)
, seq(0)
{
}
bool
operator==(const Context& lhs, const Context& rhs)
{
// TODO(RLB) Compare KDF and AEAD algorithms
auto suite = (lhs.suite == rhs.suite);
auto key = (lhs.key == rhs.key);
auto nonce = (lhs.nonce == rhs.nonce);
auto exporter_secret = (lhs.exporter_secret == rhs.exporter_secret);
auto seq = (lhs.seq == rhs.seq);
return suite && key && nonce && exporter_secret && seq;
}
SenderContext::SenderContext(Context&& c)
: Context(std::move(c))
{
}
bytes
SenderContext::seal(const bytes& aad, const bytes& pt)
{
auto ct = aead.seal(key, current_nonce(), aad, pt);
increment_seq();
return ct;
}
ReceiverContext::ReceiverContext(Context&& c)
: Context(std::move(c))
{
}
std::optional<bytes>
ReceiverContext::open(const bytes& aad, const bytes& ct)
{
auto maybe_pt = aead.open(key, current_nonce(), aad, ct);
increment_seq();
return maybe_pt;
}
///
/// HPKE
///
static const bytes default_psk = {};
static const bytes default_psk_id = {};
static bytes
suite_id(KEM::ID kem_id, KDF::ID kdf_id, AEAD::ID aead_id)
{
return label_hpke() + i2osp(static_cast<uint64_t>(kem_id), 2) +
i2osp(static_cast<uint64_t>(kdf_id), 2) +
i2osp(static_cast<uint64_t>(aead_id), 2);
}
static const KEM&
select_kem(KEM::ID id)
{
switch (id) {
case KEM::ID::DHKEM_P256_SHA256:
return KEM::get<KEM::ID::DHKEM_P256_SHA256>();
case KEM::ID::DHKEM_P384_SHA384:
return KEM::get<KEM::ID::DHKEM_P384_SHA384>();
case KEM::ID::DHKEM_P521_SHA512:
return KEM::get<KEM::ID::DHKEM_P521_SHA512>();
case KEM::ID::DHKEM_X25519_SHA256:
return KEM::get<KEM::ID::DHKEM_X25519_SHA256>();
#if !defined(WITH_BORINGSSL)
case KEM::ID::DHKEM_X448_SHA512:
return KEM::get<KEM::ID::DHKEM_X448_SHA512>();
#endif
default:
throw std::runtime_error("Unsupported algorithm");
}
}
static const KDF&
select_kdf(KDF::ID id)
{
switch (id) {
case KDF::ID::HKDF_SHA256:
return KDF::get<KDF::ID::HKDF_SHA256>();
case KDF::ID::HKDF_SHA384:
return KDF::get<KDF::ID::HKDF_SHA384>();
case KDF::ID::HKDF_SHA512:
return KDF::get<KDF::ID::HKDF_SHA512>();
default:
throw std::runtime_error("Unsupported algorithm");
}
}
static const AEAD&
select_aead(AEAD::ID id)
{
switch (id) {
case AEAD::ID::AES_128_GCM:
return AEAD::get<AEAD::ID::AES_128_GCM>();
case AEAD::ID::AES_256_GCM:
return AEAD::get<AEAD::ID::AES_256_GCM>();
case AEAD::ID::CHACHA20_POLY1305:
return AEAD::get<AEAD::ID::CHACHA20_POLY1305>();
case AEAD::ID::export_only:
return AEAD::get<AEAD::ID::export_only>();
default:
throw std::runtime_error("Unsupported algorithm");
}
}
HPKE::HPKE(KEM::ID kem_id, KDF::ID kdf_id, AEAD::ID aead_id)
: suite(suite_id(kem_id, kdf_id, aead_id))
, kem(select_kem(kem_id))
, kdf(select_kdf(kdf_id))
, aead(select_aead(aead_id))
{
}
HPKE::SenderInfo
HPKE::setup_base_s(const KEM::PublicKey& pkR, const bytes& info) const
{
auto [shared_secret, enc] = kem.encap(pkR);
auto ctx =
key_schedule(Mode::base, shared_secret, info, default_psk, default_psk_id);
return std::make_pair(enc, SenderContext(std::move(ctx)));
}
ReceiverContext
HPKE::setup_base_r(const bytes& enc,
const KEM::PrivateKey& skR,
const bytes& info) const
{
auto pkRm = kem.serialize(*skR.public_key());
auto shared_secret = kem.decap(enc, skR);
auto ctx =
key_schedule(Mode::base, shared_secret, info, default_psk, default_psk_id);
return { std::move(ctx) };
}
HPKE::SenderInfo
HPKE::setup_psk_s(const KEM::PublicKey& pkR,
const bytes& info,
const bytes& psk,
const bytes& psk_id) const
{
auto [shared_secret, enc] = kem.encap(pkR);
auto ctx = key_schedule(Mode::psk, shared_secret, info, psk, psk_id);
return std::make_pair(enc, SenderContext(std::move(ctx)));
}
ReceiverContext
HPKE::setup_psk_r(const bytes& enc,
const KEM::PrivateKey& skR,
const bytes& info,
const bytes& psk,
const bytes& psk_id) const
{
auto shared_secret = kem.decap(enc, skR);
auto ctx = key_schedule(Mode::psk, shared_secret, info, psk, psk_id);
return { std::move(ctx) };
}
HPKE::SenderInfo
HPKE::setup_auth_s(const KEM::PublicKey& pkR,
const bytes& info,
const KEM::PrivateKey& skS) const
{
auto [shared_secret, enc] = kem.auth_encap(pkR, skS);
auto ctx =
key_schedule(Mode::auth, shared_secret, info, default_psk, default_psk_id);
return std::make_pair(enc, SenderContext(std::move(ctx)));
}
ReceiverContext
HPKE::setup_auth_r(const bytes& enc,
const KEM::PrivateKey& skR,
const bytes& info,
const KEM::PublicKey& pkS) const
{
auto shared_secret = kem.auth_decap(enc, pkS, skR);
auto ctx =
key_schedule(Mode::auth, shared_secret, info, default_psk, default_psk_id);
return { std::move(ctx) };
}
HPKE::SenderInfo
HPKE::setup_auth_psk_s(const KEM::PublicKey& pkR,
const bytes& info,
const bytes& psk,
const bytes& psk_id,
const KEM::PrivateKey& skS) const
{
auto [shared_secret, enc] = kem.auth_encap(pkR, skS);
auto ctx = key_schedule(Mode::auth_psk, shared_secret, info, psk, psk_id);
return std::make_pair(enc, SenderContext(std::move(ctx)));
}
ReceiverContext
HPKE::setup_auth_psk_r(const bytes& enc,
const KEM::PrivateKey& skR,
const bytes& info,
const bytes& psk,
const bytes& psk_id,
const KEM::PublicKey& pkS) const
{
auto shared_secret = kem.auth_decap(enc, pkS, skR);
auto ctx = key_schedule(Mode::auth_psk, shared_secret, info, psk, psk_id);
return { std::move(ctx) };
}
bool
HPKE::verify_psk_inputs(Mode mode, const bytes& psk, const bytes& psk_id)
{
auto got_psk = (psk != default_psk);
auto got_psk_id = (psk_id != default_psk_id);
if (got_psk != got_psk_id) {
return false;
}
return (!got_psk && (mode == Mode::base || mode == Mode::auth)) ||
(got_psk && (mode == Mode::psk || mode == Mode::auth_psk));
}
Context
HPKE::key_schedule(Mode mode,
const bytes& shared_secret,
const bytes& info,
const bytes& psk,
const bytes& psk_id) const
{
if (!verify_psk_inputs(mode, psk, psk_id)) {
throw std::runtime_error("Invalid PSK inputs");
}
auto psk_id_hash =
kdf.labeled_extract(suite, {}, label_psk_id_hash(), psk_id);
auto info_hash = kdf.labeled_extract(suite, {}, label_info_hash(), info);
auto mode_bytes = bytes{ uint8_t(mode) };
auto key_schedule_context = mode_bytes + psk_id_hash + info_hash;
auto secret = kdf.labeled_extract(suite, shared_secret, label_secret(), psk);
auto key = kdf.labeled_expand(
suite, secret, label_key(), key_schedule_context, aead.key_size);
auto nonce = kdf.labeled_expand(
suite, secret, label_base_nonce(), key_schedule_context, aead.nonce_size);
auto exporter_secret = kdf.labeled_expand(
suite, secret, label_exp(), key_schedule_context, kdf.hash_size);
return { suite, key, nonce, exporter_secret, kdf, aead };
}
} // namespace mlspp::hpke

View File

@@ -0,0 +1,160 @@
#include "openssl_common.h"
#include <openssl/ec.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/hmac.h>
#include <openssl/x509.h>
#include <openssl/x509v3.h>
#if defined(WITH_OPENSSL3)
#include <openssl/param_build.h>
#endif
namespace mlspp::hpke {
template<>
void
typed_delete(EVP_CIPHER_CTX* ptr)
{
EVP_CIPHER_CTX_free(ptr);
}
#if WITH_BORINGSSL
template<>
void
typed_delete(EVP_AEAD_CTX* ptr)
{
EVP_AEAD_CTX_free(ptr);
}
#endif
template<>
void
typed_delete(EVP_PKEY_CTX* ptr)
{
EVP_PKEY_CTX_free(ptr);
}
template<>
void
typed_delete(EVP_MD_CTX* ptr)
{
EVP_MD_CTX_free(ptr);
}
#if !defined(WITH_OPENSSL3)
template<>
void
typed_delete(HMAC_CTX* ptr)
{
HMAC_CTX_free(ptr);
}
#endif
template<>
void
typed_delete(EVP_PKEY* ptr)
{
EVP_PKEY_free(ptr);
}
template<>
void
typed_delete(BIGNUM* ptr)
{
BN_free(ptr);
}
template<>
void
typed_delete(EC_POINT* ptr)
{
EC_POINT_free(ptr);
}
#if !defined(WITH_OPENSSL3)
template<>
void
typed_delete(EC_KEY* ptr)
{
EC_KEY_free(ptr);
}
#endif
#if defined(WITH_OPENSSL3)
template<>
void
typed_delete(EVP_MAC* ptr)
{
EVP_MAC_free(ptr);
}
template<>
void
typed_delete(EVP_MAC_CTX* ptr)
{
EVP_MAC_CTX_free(ptr);
}
template<>
void
typed_delete(EC_GROUP* ptr)
{
EC_GROUP_free(ptr);
}
template<>
void
typed_delete(OSSL_PARAM_BLD* ptr)
{
OSSL_PARAM_BLD_free(ptr);
}
template<>
void
typed_delete(OSSL_PARAM* ptr)
{
OSSL_PARAM_free(ptr);
}
#endif
template<>
void
typed_delete(X509* ptr)
{
X509_free(ptr);
}
template<>
void
typed_delete(STACK_OF(GENERAL_NAME) * ptr)
{
sk_GENERAL_NAME_pop_free(ptr, GENERAL_NAME_free);
}
template<>
void
typed_delete(BIO* ptr)
{
BIO_vfree(ptr);
}
template<>
void
typed_delete(ASN1_TIME* ptr)
{
ASN1_TIME_free(ptr);
}
///
/// Map OpenSSL errors to C++ exceptions
///
std::runtime_error
openssl_error()
{
auto code = ERR_get_error();
return std::runtime_error(ERR_error_string(code, nullptr));
}
} // namespace mlspp::hpke

View File

@@ -0,0 +1,25 @@
#pragma once
#include <hpke/hpke.h>
#include <stdexcept>
namespace mlspp::hpke {
template<typename T>
void
typed_delete(T* ptr);
template<typename T>
using typed_unique_ptr = std::unique_ptr<T, decltype(&typed_delete<T>)>;
template<typename T>
typed_unique_ptr<T>
make_typed_unique(T* ptr)
{
return typed_unique_ptr<T>(ptr, typed_delete<T>);
}
std::runtime_error
openssl_error();
} // namespace mlspp::hpke

View File

@@ -0,0 +1,19 @@
#include <hpke/random.h>
#include "openssl_common.h"
#include <openssl/rand.h>
namespace mlspp::hpke {
bytes
random_bytes(size_t size)
{
auto rand = bytes(size);
if (1 != RAND_bytes(rand.data(), static_cast<int>(size))) {
throw openssl_error();
}
return rand;
}
} // namespace mlspp::hpke

207
DPP/mlspp/lib/hpke/src/rsa.cpp Executable file
View File

@@ -0,0 +1,207 @@
#include "rsa.h"
#include "common.h"
#include "openssl/rsa.h"
#include "openssl_common.h"
namespace mlspp::hpke {
std::unique_ptr<Signature::PrivateKey>
RSASignature::generate_key_pair() const
{
throw std::runtime_error("Not implemented");
}
std::unique_ptr<Signature::PrivateKey>
RSASignature::derive_key_pair(const bytes& /*ikm*/) const
{
throw std::runtime_error("Not implemented");
}
std::unique_ptr<Signature::PrivateKey>
RSASignature::generate_key_pair(size_t bits)
{
auto ctx = make_typed_unique(EVP_PKEY_CTX_new_id(EVP_PKEY_RSA, nullptr));
if (ctx == nullptr) {
throw openssl_error();
}
if (EVP_PKEY_keygen_init(ctx.get()) <= 0) {
throw openssl_error();
}
// NOLINTNEXTLINE(hicpp-signed-bitwise)
if (EVP_PKEY_CTX_set_rsa_keygen_bits(ctx.get(), static_cast<int>(bits)) <=
0) {
throw openssl_error();
}
auto* pkey = static_cast<EVP_PKEY*>(nullptr);
if (EVP_PKEY_keygen(ctx.get(), &pkey) <= 0) {
throw openssl_error();
}
return std::make_unique<PrivateKey>(pkey);
}
// TODO(rlb): Implement derive() with sizes
bytes
RSASignature::serialize(const Signature::PublicKey& pk) const
{
const auto& rpk = dynamic_cast<const PublicKey&>(pk);
const int len = i2d_PublicKey(rpk.pkey.get(), nullptr);
auto raw = bytes(len);
auto* data_ptr = raw.data();
if (len != i2d_PublicKey(rpk.pkey.get(), &data_ptr)) {
throw openssl_error();
}
return raw;
}
std::unique_ptr<Signature::PublicKey>
RSASignature::deserialize(const bytes& enc) const
{
const auto* data_ptr = enc.data();
auto* pkey = d2i_PublicKey(
EVP_PKEY_RSA, nullptr, &data_ptr, static_cast<int>(enc.size()));
if (pkey == nullptr) {
throw openssl_error();
}
return std::make_unique<RSASignature::PublicKey>(pkey);
}
bytes
RSASignature::serialize_private(const Signature::PrivateKey& sk) const
{
const auto& rsk = dynamic_cast<const PrivateKey&>(sk);
const int len = i2d_PrivateKey(rsk.pkey.get(), nullptr);
auto raw = bytes(len);
auto* data_ptr = raw.data();
if (len != i2d_PrivateKey(rsk.pkey.get(), &data_ptr)) {
throw openssl_error();
}
return raw;
}
std::unique_ptr<Signature::PrivateKey>
RSASignature::deserialize_private(const bytes& skm) const
{
const auto* data_ptr = skm.data();
auto* pkey = d2i_PrivateKey(
EVP_PKEY_RSA, nullptr, &data_ptr, static_cast<int>(skm.size()));
if (pkey == nullptr) {
throw openssl_error();
}
return std::make_unique<RSASignature::PrivateKey>(pkey);
}
bytes
RSASignature::sign(const bytes& data, const Signature::PrivateKey& sk) const
{
const auto& rsk = dynamic_cast<const PrivateKey&>(sk);
auto ctx = make_typed_unique(EVP_MD_CTX_create());
if (ctx == nullptr) {
throw openssl_error();
}
if (1 !=
EVP_DigestSignInit(ctx.get(), nullptr, md, nullptr, rsk.pkey.get())) {
throw openssl_error();
}
size_t siglen = EVP_PKEY_size(rsk.pkey.get());
bytes sig(siglen);
if (1 != EVP_DigestSign(
ctx.get(), sig.data(), &siglen, data.data(), data.size())) {
throw openssl_error();
}
sig.resize(siglen);
return sig;
}
bool
RSASignature::verify(const bytes& data,
const bytes& sig,
const Signature::PublicKey& pk) const
{
const auto& rpk = dynamic_cast<const PublicKey&>(pk);
auto ctx = make_typed_unique(EVP_MD_CTX_create());
if (ctx == nullptr) {
throw openssl_error();
}
if (1 !=
EVP_DigestVerifyInit(ctx.get(), nullptr, md, nullptr, rpk.pkey.get())) {
throw openssl_error();
}
auto rv = EVP_DigestVerify(
ctx.get(), sig.data(), sig.size(), data.data(), data.size());
return rv == 1;
}
// TODO(RLB) Implement these methods. No concrete need, but might be nice for
// completeness.
std::unique_ptr<Signature::PrivateKey>
RSASignature::import_jwk_private(const std::string& /* json_str */) const
{
throw std::runtime_error("not implemented");
}
std::unique_ptr<Signature::PublicKey>
RSASignature::import_jwk(const std::string& /* json_str */) const
{
throw std::runtime_error("not implemented");
}
std::string
RSASignature::export_jwk_private(const Signature::PrivateKey& /* sk */) const
{
throw std::runtime_error("not implemented");
}
std::string
RSASignature::export_jwk(const Signature::PublicKey& /* pk */) const
{
throw std::runtime_error("not implemented");
}
const EVP_MD*
RSASignature::digest_to_md(Digest::ID digest)
{
// NOLINTNEXTLINE(hicpp-multiway-paths-covered)
switch (digest) {
case Digest::ID::SHA256:
return EVP_sha256();
case Digest::ID::SHA384:
return EVP_sha384();
case Digest::ID::SHA512:
return EVP_sha512();
default:
throw std::runtime_error("Unsupported digest");
}
}
Signature::ID
RSASignature::digest_to_sig(Digest::ID digest)
{
// NOLINTNEXTLINE(hicpp-multiway-paths-covered)
switch (digest) {
case Digest::ID::SHA256:
return Signature::ID::RSA_SHA256;
case Digest::ID::SHA384:
return Signature::ID::RSA_SHA384;
case Digest::ID::SHA512:
return Signature::ID::RSA_SHA512;
default:
throw std::runtime_error("Unsupported digest");
}
}
} // namespace mlspp::hpke

97
DPP/mlspp/lib/hpke/src/rsa.h Executable file
View File

@@ -0,0 +1,97 @@
#pragma once
#include <hpke/digest.h>
#include <hpke/hpke.h>
#include <hpke/signature.h>
#include "openssl_common.h"
#include <openssl/evp.h>
#include <openssl/rsa.h>
namespace mlspp::hpke {
// XXX(RLB): There is a lot of code in RSASignature that is duplicated in
// EVPGroup. I have allowed this duplication rather than factoring it out
// because I would like to be able to cleanly remove RSA later.
struct RSASignature : public Signature
{
struct PublicKey : public Signature::PublicKey
{
explicit PublicKey(EVP_PKEY* pkey_in)
: pkey(pkey_in, typed_delete<EVP_PKEY>)
{
}
~PublicKey() override = default;
typed_unique_ptr<EVP_PKEY> pkey;
};
struct PrivateKey : public Signature::PrivateKey
{
explicit PrivateKey(EVP_PKEY* pkey_in)
: pkey(pkey_in, typed_delete<EVP_PKEY>)
{
}
~PrivateKey() override = default;
std::unique_ptr<Signature::PublicKey> public_key() const override
{
if (1 != EVP_PKEY_up_ref(pkey.get())) {
throw openssl_error();
}
return std::make_unique<PublicKey>(pkey.get());
}
typed_unique_ptr<EVP_PKEY> pkey;
};
explicit RSASignature(Digest::ID digest)
: Signature(digest_to_sig(digest))
, md(digest_to_md(digest))
{
}
std::unique_ptr<Signature::PrivateKey> generate_key_pair() const override;
std::unique_ptr<Signature::PrivateKey> derive_key_pair(
const bytes& /*ikm*/) const override;
static std::unique_ptr<Signature::PrivateKey> generate_key_pair(size_t bits);
// TODO(rlb): Implement derive() with sizes
bytes serialize(const Signature::PublicKey& pk) const override;
std::unique_ptr<Signature::PublicKey> deserialize(
const bytes& enc) const override;
bytes serialize_private(const Signature::PrivateKey& sk) const override;
std::unique_ptr<Signature::PrivateKey> deserialize_private(
const bytes& skm) const override;
bytes sign(const bytes& data, const Signature::PrivateKey& sk) const override;
bool verify(const bytes& data,
const bytes& sig,
const Signature::PublicKey& pk) const override;
std::unique_ptr<Signature::PrivateKey> import_jwk_private(
const std::string& json_str) const override;
std::unique_ptr<Signature::PublicKey> import_jwk(
const std::string& json_str) const override;
std::string export_jwk_private(
const Signature::PrivateKey& sk) const override;
std::string export_jwk(const Signature::PublicKey& pk) const override;
private:
const EVP_MD* md;
static const EVP_MD* digest_to_md(Digest::ID digest);
static Signature::ID digest_to_sig(Digest::ID digest);
};
} // namespace mlspp::hpke

View File

@@ -0,0 +1,344 @@
#include <hpke/base64.h>
#include <hpke/digest.h>
#include <hpke/signature.h>
#include <string>
#include "dhkem.h"
#include "rsa.h"
#include <dpp/json.h>
#include <openssl/bio.h>
#include <openssl/bn.h>
#include <openssl/ec.h>
#include <openssl/evp.h>
#include <openssl/rsa.h>
using nlohmann::json;
namespace mlspp::hpke {
struct GroupSignature : public Signature
{
struct PrivateKey : public Signature::PrivateKey
{
explicit PrivateKey(Group::PrivateKey* group_priv_in)
: group_priv(group_priv_in)
{
}
std::unique_ptr<Signature::PublicKey> public_key() const override
{
return group_priv->public_key();
}
std::unique_ptr<Group::PrivateKey> group_priv;
};
static Signature::ID group_to_sig(Group::ID group_id)
{
switch (group_id) {
case Group::ID::P256:
return Signature::ID::P256_SHA256;
case Group::ID::P384:
return Signature::ID::P384_SHA384;
case Group::ID::P521:
return Signature::ID::P521_SHA512;
case Group::ID::Ed25519:
return Signature::ID::Ed25519;
#if !defined(WITH_BORINGSSL)
case Group::ID::Ed448:
return Signature::ID::Ed448;
#endif
default:
throw std::runtime_error("Unsupported group");
}
}
explicit GroupSignature(const Group& group_in)
: Signature(group_to_sig(group_in.id))
, group(group_in)
{
}
std::unique_ptr<Signature::PrivateKey> generate_key_pair() const override
{
return std::make_unique<PrivateKey>(group.generate_key_pair().release());
}
std::unique_ptr<Signature::PrivateKey> derive_key_pair(
const bytes& ikm) const override
{
return std::make_unique<PrivateKey>(
group.derive_key_pair({}, ikm).release());
}
bytes serialize(const Signature::PublicKey& pk) const override
{
const auto& rpk = dynamic_cast<const Group::PublicKey&>(pk);
return group.serialize(rpk);
}
std::unique_ptr<Signature::PublicKey> deserialize(
const bytes& enc) const override
{
return group.deserialize(enc);
}
bytes serialize_private(const Signature::PrivateKey& sk) const override
{
const auto& rsk = dynamic_cast<const PrivateKey&>(sk);
return group.serialize_private(*rsk.group_priv);
}
std::unique_ptr<Signature::PrivateKey> deserialize_private(
const bytes& skm) const override
{
return std::make_unique<PrivateKey>(
group.deserialize_private(skm).release());
}
bytes sign(const bytes& data, const Signature::PrivateKey& sk) const override
{
const auto& rsk = dynamic_cast<const PrivateKey&>(sk);
return group.sign(data, *rsk.group_priv);
}
bool verify(const bytes& data,
const bytes& sig,
const Signature::PublicKey& pk) const override
{
const auto& rpk = dynamic_cast<const Group::PublicKey&>(pk);
return group.verify(data, sig, rpk);
}
std::unique_ptr<Signature::PrivateKey> import_jwk_private(
const std::string& jwk_json) const override
{
const auto jwk = validate_jwk_json(jwk_json, true);
const auto d = from_base64url(jwk.at("d"));
auto gsk = group.deserialize_private(d);
return std::make_unique<PrivateKey>(gsk.release());
}
std::unique_ptr<Signature::PublicKey> import_jwk(
const std::string& jwk_json) const override
{
const auto jwk = validate_jwk_json(jwk_json, false);
const auto x = from_base64url(jwk.at("x"));
auto y = bytes{};
if (jwk.contains("y")) {
y = from_base64url(jwk.at("y"));
}
return group.public_key_from_coordinates(x, y);
}
std::string export_jwk(const Signature::PublicKey& pk) const override
{
const auto& gpk = dynamic_cast<const Group::PublicKey&>(pk);
const auto jwk_json = export_jwk_json(gpk);
return jwk_json.dump();
}
std::string export_jwk_private(const Signature::PrivateKey& sk) const override
{
const auto& gssk = dynamic_cast<const GroupSignature::PrivateKey&>(sk);
const auto& gsk = gssk.group_priv;
const auto gpk = gsk->public_key();
auto jwk_json = export_jwk_json(*gpk);
// encode the private key
const auto enc = serialize_private(sk);
jwk_json.emplace("d", to_base64url(enc));
return jwk_json.dump();
}
private:
const Group& group;
json validate_jwk_json(const std::string& jwk_json, bool private_key) const
{
json jwk = json::parse(jwk_json);
if (jwk.empty() || !jwk.contains("kty") || !jwk.contains("crv") ||
!jwk.contains("x") || (private_key && !jwk.contains("d"))) {
throw std::runtime_error("malformed JWK");
}
if (jwk.at("kty") != group.jwk_key_type) {
throw std::runtime_error("invalid JWK key type");
}
if (jwk.at("crv") != group.jwk_curve_name) {
throw std::runtime_error("invalid JWK curve");
}
return jwk;
}
json export_jwk_json(const Group::PublicKey& pk) const
{
const auto [x, y] = group.coordinates(pk);
json jwk = json::object({
{ "crv", group.jwk_curve_name },
{ "kty", group.jwk_key_type },
});
if (group.jwk_key_type == "EC") {
jwk.emplace("x", to_base64url(x));
jwk.emplace("y", to_base64url(y));
} else if (group.jwk_key_type == "OKP") {
jwk.emplace("x", to_base64url(x));
} else {
throw std::runtime_error("unknown key type");
}
return jwk;
}
};
template<>
const Signature&
Signature::get<Signature::ID::P256_SHA256>()
{
static const auto instance = GroupSignature(Group::get<Group::ID::P256>());
return instance;
}
template<>
const Signature&
Signature::get<Signature::ID::P384_SHA384>()
{
static const auto instance = GroupSignature(Group::get<Group::ID::P384>());
return instance;
}
template<>
const Signature&
Signature::get<Signature::ID::P521_SHA512>()
{
static const auto instance = GroupSignature(Group::get<Group::ID::P521>());
return instance;
}
template<>
const Signature&
Signature::get<Signature::ID::Ed25519>()
{
static const auto instance = GroupSignature(Group::get<Group::ID::Ed25519>());
return instance;
}
#if !defined(WITH_BORINGSSL)
template<>
const Signature&
Signature::get<Signature::ID::Ed448>()
{
static const auto instance = GroupSignature(Group::get<Group::ID::Ed448>());
return instance;
}
#endif
template<>
const Signature&
Signature::get<Signature::ID::RSA_SHA256>()
{
static const auto instance = RSASignature(Digest::ID::SHA256);
return instance;
}
template<>
const Signature&
Signature::get<Signature::ID::RSA_SHA384>()
{
static const auto instance = RSASignature(Digest::ID::SHA384);
return instance;
}
template<>
const Signature&
Signature::get<Signature::ID::RSA_SHA512>()
{
static const auto instance = RSASignature(Digest::ID::SHA512);
return instance;
}
Signature::Signature(Signature::ID id_in)
: id(id_in)
{
}
std::unique_ptr<Signature::PrivateKey>
Signature::generate_rsa(size_t bits)
{
return RSASignature::generate_key_pair(bits);
}
static const Signature&
sig_from_jwk(const std::string& jwk_json)
{
using KeyTypeAndCurve = std::tuple<std::string, std::string>;
static const auto alg_sig_map = std::map<KeyTypeAndCurve, const Signature&>
{
{ { "EC", "P-256" }, Signature::get<Signature::ID::P256_SHA256>() },
{ { "EC", "P-384" }, Signature::get<Signature::ID::P384_SHA384>() },
{ { "EC", "P-512" }, Signature::get<Signature::ID::P521_SHA512>() },
{ { "OKP", "Ed25519" }, Signature::get<Signature::ID::Ed25519>() },
#if !defined(WITH_BORINGSSL)
{ { "OKP", "Ed448" }, Signature::get<Signature::ID::Ed448>() },
#endif
// TODO(RLB): RSA
};
const auto jwk = json::parse(jwk_json);
const auto& kty = jwk.at("kty");
auto crv = std::string("");
if (jwk.contains("crv")) {
crv = jwk.at("crv");
}
const auto key = KeyTypeAndCurve{ kty, crv };
return alg_sig_map.at(key);
}
Signature::PrivateJWK
Signature::parse_jwk_private(const std::string& jwk_json)
{
// XXX(RLB): This JSON-parses the JWK twice. I'm assuming that this is a less
// bad cost than changing the import_jwk method signature to take `json`.
const auto& sig = sig_from_jwk(jwk_json);
const auto jwk = json::parse(jwk_json);
auto priv = sig.import_jwk_private(jwk_json);
auto kid = std::optional<std::string>{};
if (jwk.contains("kid")) {
kid = jwk.at("kid").get<std::string>();
}
return { sig, kid, std::move(priv) };
}
Signature::PublicJWK
Signature::parse_jwk(const std::string& jwk_json)
{
// XXX(RLB): Same double-parsing comment as with `parse_jwk_private`
const auto& sig = sig_from_jwk(jwk_json);
const auto jwk = json::parse(jwk_json);
auto pub = sig.import_jwk(jwk_json);
auto kid = std::optional<std::string>{};
if (jwk.contains("kid")) {
kid = jwk.at("kid").get<std::string>();
}
return { sig, kid, std::move(pub) };
}
} // namespace mlspp::hpke

View File

@@ -0,0 +1,401 @@
#include <hpke/base64.h>
#include <hpke/signature.h>
#include <hpke/userinfo_vc.h>
#include <dpp/json.h>
#include <tls/compat.h>
using nlohmann::json;
namespace mlspp::hpke {
static const std::string name_attr = "name";
static const std::string sub_attr = "sub";
static const std::string given_name_attr = "given_name";
static const std::string family_name_attr = "family_name";
static const std::string middle_name_attr = "middle_name";
static const std::string nickname_attr = "nickname";
static const std::string preferred_username_attr = "preferred_username";
static const std::string profile_attr = "profile";
static const std::string picture_attr = "picture";
static const std::string website_attr = "website";
static const std::string email_attr = "email";
static const std::string email_verified_attr = "email_verified";
static const std::string gender_attr = "gender";
static const std::string birthdate_attr = "birthdate";
static const std::string zoneinfo_attr = "zoneinfo";
static const std::string locale_attr = "locale";
static const std::string phone_number_attr = "phone_number";
static const std::string phone_number_verified_attr = "phone_number_verified";
static const std::string address_attr = "address";
static const std::string address_formatted_attr = "formatted";
static const std::string address_street_address_attr = "street_address";
static const std::string address_locality_attr = "locality";
static const std::string address_region_attr = "region";
static const std::string address_postal_code_attr = "postal_code";
static const std::string address_country_attr = "country";
static const std::string updated_at_attr = "updated_at";
template<typename T>
static std::optional<T>
get_optional(const json& json_object, const std::string& field_name)
{
if (!json_object.contains(field_name)) {
return std::nullopt;
}
return { json_object.at(field_name).get<T>() };
}
///
/// ParsedCredential
///
static const Signature&
signature_from_alg(const std::string& alg)
{
static const auto alg_sig_map = std::map<std::string, const Signature&>
{
{ "ES256", Signature::get<Signature::ID::P256_SHA256>() },
{ "ES384", Signature::get<Signature::ID::P384_SHA384>() },
{ "ES512", Signature::get<Signature::ID::P521_SHA512>() },
{ "Ed25519", Signature::get<Signature::ID::Ed25519>() },
#if !defined(WITH_BORINGSSL)
{ "Ed448", Signature::get<Signature::ID::Ed448>() },
#endif
{ "RS256", Signature::get<Signature::ID::RSA_SHA256>() },
{ "RS384", Signature::get<Signature::ID::RSA_SHA384>() },
{ "RS512", Signature::get<Signature::ID::RSA_SHA512>() },
};
return alg_sig_map.at(alg);
}
static std::chrono::system_clock::time_point
epoch_time(int64_t seconds_since_epoch)
{
const auto delta = std::chrono::seconds(seconds_since_epoch);
return std::chrono::system_clock::time_point(delta);
}
static bool
is_ecdsa(const Signature& sig)
{
return sig.id == Signature::ID::P256_SHA256 ||
sig.id == Signature::ID::P384_SHA384 ||
sig.id == Signature::ID::P521_SHA512;
}
// OpenSSL expects ECDSA signatures to be in DER form. JWS provides the
// signature in raw R||S form. So we need to do some manual DER encoding.
static bytes
jws_to_der_sig(const bytes& jws_sig)
{
// Inputs that are too large will result in invalid DER encodings with this
// code. At this size, the combination of the DER integer headers and the
// integer data will overflow the one-byte DER struct length.
static const auto max_sig_size = size_t(250);
if (jws_sig.size() > max_sig_size) {
throw std::runtime_error("JWS signature too large");
}
if (jws_sig.size() % 2 != 0) {
throw std::runtime_error("Malformed JWS signature");
}
const auto int_size = jws_sig.size() / 2;
const auto jws_sig_cut =
jws_sig.begin() + static_cast<std::ptrdiff_t>(int_size);
// Compute the encoded size of R and S integer data, adding a zero byte if
// needed to clear the sign bit
const auto r_big = (jws_sig.at(0) >= 0x80);
const auto s_big = (jws_sig.at(int_size) >= 0x80);
const auto r_size = int_size + (r_big ? 1 : 0);
const auto s_size = int_size + (s_big ? 1 : 0);
// Compute the size of the DER-encoded signature
static const auto int_header_size = 2;
const auto r_int_size = int_header_size + r_size;
const auto s_int_size = int_header_size + s_size;
const auto content_size = r_int_size + s_int_size;
const auto content_big = (content_size > 0x80);
auto der_header_size = 2 + (content_big ? 1 : 0);
const auto der_size = der_header_size + content_size;
// Allocate the DER buffer
auto der = bytes(der_size, 0);
// Write the header
der.at(0) = 0x30;
if (content_big) {
der.at(1) = 0x81;
der.at(2) = static_cast<uint8_t>(content_size);
} else {
der.at(1) = static_cast<uint8_t>(content_size);
}
// Write R, virtually padding with a zero byte if needed
const auto r_start = der_header_size;
const auto r_data_start = r_start + int_header_size + (r_big ? 1 : 0);
const auto r_data_begin =
der.begin() + static_cast<std::ptrdiff_t>(r_data_start);
der.at(r_start) = 0x02;
der.at(r_start + 1) = static_cast<uint8_t>(r_size);
std::copy(jws_sig.begin(), jws_sig_cut, r_data_begin);
// Write S, virtually padding with a zero byte if needed
const auto s_start = der_header_size + r_int_size;
const auto s_data_start = s_start + int_header_size + (s_big ? 1 : 0);
const auto s_data_begin =
der.begin() + static_cast<std::ptrdiff_t>(s_data_start);
der.at(s_start) = 0x02;
der.at(s_start + 1) = static_cast<uint8_t>(s_size);
std::copy(jws_sig_cut, jws_sig.end(), s_data_begin);
return der;
}
struct UserInfoVC::ParsedCredential
{
// Header fields
const Signature& signature_algorithm; // `alg`
std::optional<std::string> key_id; // `kid`
// Top-level Payload fields
std::string issuer; // `iss`
std::chrono::system_clock::time_point not_before; // `nbf`
std::chrono::system_clock::time_point not_after; // `exp`
// Credential subject fields
UserInfoClaims credential_subject;
Signature::PublicJWK public_key;
// Signature verification information
bytes to_be_signed;
bytes signature;
ParsedCredential(const Signature& signature_algorithm_in,
std::optional<std::string> key_id_in,
std::string issuer_in,
std::chrono::system_clock::time_point not_before_in,
std::chrono::system_clock::time_point not_after_in,
UserInfoClaims credential_subject_in,
Signature::PublicJWK&& public_key_in,
bytes to_be_signed_in,
bytes signature_in)
: signature_algorithm(signature_algorithm_in)
, key_id(std::move(key_id_in))
, issuer(std::move(issuer_in))
, not_before(not_before_in)
, not_after(not_after_in)
, credential_subject(std::move(credential_subject_in))
, public_key(std::move(public_key_in))
, to_be_signed(std::move(to_be_signed_in))
, signature(std::move(signature_in))
{
}
static std::shared_ptr<ParsedCredential> parse(const std::string& jwt)
{
// Split the JWT into its header, payload, and signature
const auto first_dot = jwt.find_first_of('.');
const auto last_dot = jwt.find_last_of('.');
if (first_dot == std::string::npos || last_dot == std::string::npos ||
first_dot == last_dot || last_dot > jwt.length() - 2) {
throw std::runtime_error("malformed JWT; not enough '.' characters");
}
const auto header_b64 = jwt.substr(0, first_dot);
const auto payload_b64 =
jwt.substr(first_dot + 1, last_dot - first_dot - 1);
const auto signature_b64 = jwt.substr(last_dot + 1);
// Parse the components
const auto header = json::parse(to_ascii(from_base64url(header_b64)));
const auto payload = json::parse(to_ascii(from_base64url(payload_b64)));
// Prepare the validation inputs
const auto hdr = header.at("alg");
const auto& sig = signature_from_alg(hdr);
const auto to_be_signed = from_ascii(header_b64 + "." + payload_b64);
auto signature = from_base64url(signature_b64);
if (is_ecdsa(sig)) {
signature = jws_to_der_sig(signature);
}
auto kid = std::optional<std::string>{};
if (header.contains("kid")) {
kid = header.at("kid").get<std::string>();
}
// Verify the VC parts
const auto& vc = payload.at("vc");
static const auto context =
std::vector<std::string>{ { "https://www.w3.org/2018/credentials/v1" } };
const auto vc_context = vc.at("@context").get<std::vector<std::string>>();
if (vc_context != context) {
throw std::runtime_error("malformed VC: incorrect context value");
}
static const auto type = std::vector<std::string>{
"VerifiableCredential",
"UserInfoCredential",
};
if (vc.at("type") != type) {
throw std::runtime_error("malformed VC: incorrect type value");
}
// Parse the subject public key
static const std::string did_jwk_prefix = "did:jwk:";
const auto id = vc.at("credentialSubject").at("id").get<std::string>();
if (id.find(did_jwk_prefix) != 0) {
throw std::runtime_error("malformed UserInfo VC: ID is not did:jwk");
}
const auto jwk = to_ascii(from_base64url(id.substr(did_jwk_prefix.size())));
auto public_key = Signature::parse_jwk(jwk);
// Extract the salient parts
return std::make_shared<ParsedCredential>(
sig,
kid,
payload.at("iss"),
epoch_time(payload.at("nbf").get<int64_t>()),
epoch_time(payload.at("exp").get<int64_t>()),
UserInfoClaims::from_json(vc.at("credentialSubject").dump()),
std::move(public_key),
to_be_signed,
signature);
}
bool verify(const Signature::PublicKey& issuer_key)
{
return signature_algorithm.verify(to_be_signed, signature, issuer_key);
}
};
///
/// UserInfoClaims
///
UserInfoClaims
UserInfoClaims::from_json(const std::string& cred_subject)
{
const auto& cred_subject_json = nlohmann::json::parse(cred_subject);
std::optional<UserInfoClaimsAddress> address_opt = {};
if (cred_subject_json.contains(address_attr)) {
auto address_json = cred_subject_json.at(address_attr);
address_opt = {
get_optional<std::string>(address_json, address_formatted_attr),
get_optional<std::string>(address_json, address_street_address_attr),
get_optional<std::string>(address_json, address_locality_attr),
get_optional<std::string>(address_json, address_region_attr),
get_optional<std::string>(address_json, address_postal_code_attr),
get_optional<std::string>(address_json, address_country_attr)
};
}
return {
get_optional<std::string>(cred_subject_json, sub_attr),
get_optional<std::string>(cred_subject_json, name_attr),
get_optional<std::string>(cred_subject_json, given_name_attr),
get_optional<std::string>(cred_subject_json, family_name_attr),
get_optional<std::string>(cred_subject_json, middle_name_attr),
get_optional<std::string>(cred_subject_json, nickname_attr),
get_optional<std::string>(cred_subject_json, preferred_username_attr),
get_optional<std::string>(cred_subject_json, profile_attr),
get_optional<std::string>(cred_subject_json, picture_attr),
get_optional<std::string>(cred_subject_json, website_attr),
get_optional<std::string>(cred_subject_json, email_attr),
get_optional<bool>(cred_subject_json, email_verified_attr),
get_optional<std::string>(cred_subject_json, gender_attr),
get_optional<std::string>(cred_subject_json, birthdate_attr),
get_optional<std::string>(cred_subject_json, zoneinfo_attr),
get_optional<std::string>(cred_subject_json, locale_attr),
get_optional<std::string>(cred_subject_json, phone_number_attr),
get_optional<bool>(cred_subject_json, phone_number_verified_attr),
address_opt,
get_optional<uint64_t>(cred_subject_json, updated_at_attr),
};
}
///
/// UserInfoVC
///
UserInfoVC::UserInfoVC(std::string jwt)
: parsed_cred(ParsedCredential::parse(jwt))
, raw(std::move(jwt))
{
}
const Signature&
UserInfoVC::signature_algorithm() const
{
return parsed_cred->signature_algorithm;
}
std::string
UserInfoVC::issuer() const
{
return parsed_cred->issuer;
}
std::optional<std::string>
UserInfoVC::key_id() const
{
return parsed_cred->key_id;
}
bool
UserInfoVC::valid_from(const Signature::PublicKey& issuer_key) const
{
return parsed_cred->verify(issuer_key);
}
const std::string&
UserInfoVC::raw_credential() const
{
return raw;
}
const UserInfoClaims&
UserInfoVC::subject() const
{
return parsed_cred->credential_subject;
}
std::chrono::system_clock::time_point
UserInfoVC::not_before() const
{
return parsed_cred->not_before;
}
std::chrono::system_clock::time_point
UserInfoVC::not_after() const
{
return parsed_cred->not_after;
}
const Signature::PublicJWK&
UserInfoVC::public_key() const
{
return parsed_cred->public_key;
}
bool
operator==(const UserInfoVC& lhs, const UserInfoVC& rhs)
{
return lhs.raw == rhs.raw;
}
} // namespace mlspp::hpke