add basic discord support

This commit is contained in:
2024-11-09 01:58:06 +03:00
parent c3e7a9c92d
commit 8965b7ee90
869 changed files with 191278 additions and 7 deletions

13
DPP/mlspp/src/common.cpp Executable file
View File

@@ -0,0 +1,13 @@
#include "mls/common.h"
namespace mlspp {
uint64_t
seconds_since_epoch()
{
// TODO(RLB) This should use std::chrono, but that seems not to be available
// on some platforms.
return std::time(nullptr);
}
} // namespace mlspp

443
DPP/mlspp/src/core_types.cpp Executable file
View File

@@ -0,0 +1,443 @@
#include "mls/core_types.h"
#include "mls/messages.h"
#include "grease.h"
#include <set>
namespace mlspp {
///
/// Extensions
///
const Extension::Type RequiredCapabilitiesExtension::type =
ExtensionType::required_capabilities;
const Extension::Type ApplicationIDExtension::type =
ExtensionType::application_id;
const std::array<uint16_t, 5> default_extensions = {
ExtensionType::application_id, ExtensionType::ratchet_tree,
ExtensionType::required_capabilities, ExtensionType::external_pub,
ExtensionType::external_senders,
};
const std::array<uint16_t, 8> default_proposals = {
ProposalType::add,
ProposalType::update,
ProposalType::remove,
ProposalType::psk,
ProposalType::reinit,
ProposalType::external_init,
ProposalType::group_context_extensions,
};
const std::array<ProtocolVersion, 1> all_supported_versions = {
ProtocolVersion::mls10
};
const std::array<CipherSuite::ID, 6> all_supported_ciphersuites = {
CipherSuite::ID::X25519_AES128GCM_SHA256_Ed25519,
CipherSuite::ID::P256_AES128GCM_SHA256_P256,
CipherSuite::ID::X25519_CHACHA20POLY1305_SHA256_Ed25519,
CipherSuite::ID::X448_AES256GCM_SHA512_Ed448,
CipherSuite::ID::P521_AES256GCM_SHA512_P521,
CipherSuite::ID::X448_CHACHA20POLY1305_SHA512_Ed448,
};
const std::array<CredentialType, 4> all_supported_credentials = {
CredentialType::basic,
CredentialType::x509,
CredentialType::userinfo_vc_draft_00,
CredentialType::multi_draft_00
};
Capabilities
Capabilities::create_default()
{
return {
{ all_supported_versions.begin(), all_supported_versions.end() },
{ all_supported_ciphersuites.begin(), all_supported_ciphersuites.end() },
{ /* No non-default extensions */ },
{ /* No non-default proposals */ },
{ all_supported_credentials.begin(), all_supported_credentials.end() },
};
}
bool
Capabilities::extensions_supported(
const std::vector<Extension::Type>& required) const
{
return stdx::all_of(required, [&](Extension::Type type) {
if (stdx::contains(default_extensions, type)) {
return true;
}
return stdx::contains(extensions, type);
});
}
bool
Capabilities::proposals_supported(
const std::vector<Proposal::Type>& required) const
{
return stdx::all_of(required, [&](Proposal::Type type) {
if (stdx::contains(default_proposals, type)) {
return true;
}
return stdx::contains(proposals, type);
});
}
bool
Capabilities::credential_supported(const Credential& credential) const
{
return stdx::contains(credentials, credential.type());
}
Lifetime
Lifetime::create_default()
{
return Lifetime{ 0x0000000000000000, 0xffffffffffffffff };
}
void
ExtensionList::add(uint16_t type, bytes data)
{
auto curr = stdx::find_if(
extensions, [&](const Extension& ext) -> bool { return ext.type == type; });
if (curr != extensions.end()) {
curr->data = std::move(data);
return;
}
extensions.push_back({ type, std::move(data) });
}
bool
ExtensionList::has(uint16_t type) const
{
return stdx::any_of(extensions,
[&](const Extension& ext) { return ext.type == type; });
}
///
/// LeafNode
///
LeafNode::LeafNode(CipherSuite cipher_suite,
HPKEPublicKey encryption_key_in,
SignaturePublicKey signature_key_in,
Credential credential_in,
Capabilities capabilities_in,
Lifetime lifetime_in,
ExtensionList extensions_in,
const SignaturePrivateKey& sig_priv)
: encryption_key(std::move(encryption_key_in))
, signature_key(std::move(signature_key_in))
, credential(std::move(credential_in))
, capabilities(std::move(capabilities_in))
, content(lifetime_in)
, extensions(std::move(extensions_in))
{
grease(extensions);
grease(capabilities, extensions);
sign(cipher_suite, sig_priv, std::nullopt);
}
void
LeafNode::set_capabilities(Capabilities capabilities_in)
{
capabilities = std::move(capabilities_in);
grease(capabilities, extensions);
}
LeafNode
LeafNode::for_update(CipherSuite cipher_suite,
const bytes& group_id,
LeafIndex leaf_index,
HPKEPublicKey encryption_key_in,
const LeafNodeOptions& opts,
const SignaturePrivateKey& sig_priv) const
{
auto clone = clone_with_options(std::move(encryption_key_in), opts);
clone.content = Empty{};
clone.sign(cipher_suite, sig_priv, { { group_id, leaf_index } });
return clone;
}
LeafNode
LeafNode::for_commit(CipherSuite cipher_suite,
const bytes& group_id,
LeafIndex leaf_index,
HPKEPublicKey encryption_key_in,
const bytes& parent_hash,
const LeafNodeOptions& opts,
const SignaturePrivateKey& sig_priv) const
{
auto clone = clone_with_options(std::move(encryption_key_in), opts);
clone.content = ParentHash{ parent_hash };
clone.sign(cipher_suite, sig_priv, { { group_id, leaf_index } });
return clone;
}
LeafNodeSource
LeafNode::source() const
{
return tls::variant<LeafNodeSource>::type(content);
}
void
LeafNode::sign(CipherSuite cipher_suite,
const SignaturePrivateKey& sig_priv,
const std::optional<MemberBinding>& binding)
{
const auto tbs = to_be_signed(binding);
if (sig_priv.public_key != signature_key) {
throw InvalidParameterError("Signature key mismatch");
}
if (!credential.valid_for(signature_key)) {
throw InvalidParameterError("Credential not valid for signature key");
}
signature = sig_priv.sign(cipher_suite, sign_label::leaf_node, tbs);
}
bool
LeafNode::verify(CipherSuite cipher_suite,
const std::optional<MemberBinding>& binding) const
{
const auto tbs = to_be_signed(binding);
if (CredentialType::x509 == credential.type()) {
const auto& cred = credential.get<X509Credential>();
if (cred.signature_scheme() !=
tls_signature_scheme(cipher_suite.sig().id)) {
throw std::runtime_error("Signature algorithm invalid");
}
}
return signature_key.verify(
cipher_suite, sign_label::leaf_node, tbs, signature);
}
bool
LeafNode::verify_expiry(uint64_t now) const
{
const auto valid = overloaded{
[now](const Lifetime& lt) {
return lt.not_before <= now && now <= lt.not_after;
},
[](const auto& /* other */) { return false; },
};
return var::visit(valid, content);
}
bool
LeafNode::verify_extension_support(const ExtensionList& ext_list) const
{
// Verify that extensions in the list are supported
auto ext_types = stdx::transform<Extension::Type>(
ext_list.extensions, [](const auto& ext) { return ext.type; });
if (!capabilities.extensions_supported(ext_types)) {
return false;
}
// If there's a RequiredCapabilities extension, verify support
const auto maybe_req_capas = ext_list.find<RequiredCapabilitiesExtension>();
if (!maybe_req_capas) {
return true;
}
const auto& req_capas = opt::get(maybe_req_capas);
return capabilities.extensions_supported(req_capas.extensions) &&
capabilities.proposals_supported(req_capas.proposals);
}
LeafNode
LeafNode::clone_with_options(HPKEPublicKey encryption_key_in,
const LeafNodeOptions& opts) const
{
auto clone = *this;
clone.encryption_key = std::move(encryption_key_in);
if (opts.credential) {
clone.credential = opt::get(opts.credential);
}
if (opts.capabilities) {
clone.capabilities = opt::get(opts.capabilities);
}
if (opts.extensions) {
clone.extensions = opt::get(opts.extensions);
}
return clone;
}
// struct {
// HPKEPublicKey encryption_key;
// SignaturePublicKey signature_key;
// Credential credential;
// Capabilities capabilities;
//
// LeafNodeSource leaf_node_source;
// select (leaf_node_source) {
// case key_package:
// Lifetime lifetime;
//
// case update:
// struct{};
//
// case commit:
// opaque parent_hash<V>;
// }
//
// Extension extensions<V>;
//
// select (leaf_node_source) {
// case key_package:
// struct{};
//
// case update:
// opaque group_id<V>;
//
// case commit:
// opaque group_id<V>;
// }
// } LeafNodeTBS;
struct LeafNodeTBS
{
const HPKEPublicKey& encryption_key;
const SignaturePublicKey& signature_key;
const Credential& credential;
const Capabilities& capabilities;
const var::variant<Lifetime, Empty, ParentHash>& content;
const ExtensionList& extensions;
TLS_SERIALIZABLE(encryption_key,
signature_key,
credential,
capabilities,
content,
extensions)
TLS_TRAITS(tls::pass,
tls::pass,
tls::pass,
tls::pass,
tls::variant<LeafNodeSource>,
tls::pass)
};
bytes
LeafNode::to_be_signed(const std::optional<MemberBinding>& binding) const
{
tls::ostream w;
w << LeafNodeTBS{
encryption_key, signature_key, credential,
capabilities, content, extensions,
};
switch (source()) {
case LeafNodeSource::key_package:
break;
case LeafNodeSource::update:
case LeafNodeSource::commit:
w << opt::get(binding);
}
return w.bytes();
}
///
/// NodeType, ParentNode, and KeyPackage
///
bytes
ParentNode::hash(CipherSuite suite) const
{
return suite.digest().hash(tls::marshal(this));
}
KeyPackage::KeyPackage()
: version(ProtocolVersion::mls10)
, cipher_suite(CipherSuite::ID::unknown)
{
}
KeyPackage::KeyPackage(CipherSuite suite_in,
HPKEPublicKey init_key_in,
LeafNode leaf_node_in,
ExtensionList extensions_in,
const SignaturePrivateKey& sig_priv_in)
: version(ProtocolVersion::mls10)
, cipher_suite(suite_in)
, init_key(std::move(init_key_in))
, leaf_node(std::move(leaf_node_in))
, extensions(std::move(extensions_in))
{
grease(extensions);
sign(sig_priv_in);
}
KeyPackageRef
KeyPackage::ref() const
{
return cipher_suite.ref(*this);
}
void
KeyPackage::sign(const SignaturePrivateKey& sig_priv)
{
auto tbs = to_be_signed();
signature = sig_priv.sign(cipher_suite, sign_label::key_package, tbs);
}
bool
KeyPackage::verify() const
{
// Verify the inner leaf node
if (!leaf_node.verify(cipher_suite, std::nullopt)) {
return false;
}
// Check that the inner leaf node is intended for use in a KeyPackage
if (leaf_node.source() != LeafNodeSource::key_package) {
return false;
}
// Verify the KeyPackage
const auto tbs = to_be_signed();
if (CredentialType::x509 == leaf_node.credential.type()) {
const auto& cred = leaf_node.credential.get<X509Credential>();
if (cred.signature_scheme() !=
tls_signature_scheme(cipher_suite.sig().id)) {
throw std::runtime_error("Signature algorithm invalid");
}
}
return leaf_node.signature_key.verify(
cipher_suite, sign_label::key_package, tbs, signature);
}
bytes
KeyPackage::to_be_signed() const
{
tls::ostream out;
out << version << cipher_suite << init_key << leaf_node << extensions;
return out.bytes();
}
} // namespace mlspp

298
DPP/mlspp/src/credential.cpp Executable file
View File

@@ -0,0 +1,298 @@
#include <hpke/certificate.h>
#include <hpke/userinfo_vc.h>
#include <mls/credential.h>
#include <tls/tls_syntax.h>
namespace mlspp {
///
/// X509Credential
///
using mlspp::hpke::Certificate; // NOLINT(misc-unused-using-decls)
using mlspp::hpke::Signature; // NOLINT(misc-unused-using-decls)
using mlspp::hpke::UserInfoVC; // NOLINT(misc-unused-using-decls)
static const Signature&
find_signature(Signature::ID id)
{
switch (id) {
case Signature::ID::P256_SHA256:
return Signature::get<Signature::ID::P256_SHA256>();
case Signature::ID::P384_SHA384:
return Signature::get<Signature::ID::P384_SHA384>();
case Signature::ID::P521_SHA512:
return Signature::get<Signature::ID::P521_SHA512>();
case Signature::ID::Ed25519:
return Signature::get<Signature::ID::Ed25519>();
#if !defined(WITH_BORINGSSL)
case Signature::ID::Ed448:
return Signature::get<Signature::ID::Ed448>();
#endif
case Signature::ID::RSA_SHA256:
return Signature::get<Signature::ID::RSA_SHA256>();
default:
throw InvalidParameterError("Unsupported algorithm");
}
}
static std::vector<X509Credential::CertData>
bytes_to_x509_credential_data(const std::vector<bytes>& data_in)
{
return stdx::transform<X509Credential::CertData>(
data_in, [](const bytes& der) { return X509Credential::CertData{ der }; });
}
X509Credential::X509Credential(const std::vector<bytes>& der_chain_in)
: der_chain(bytes_to_x509_credential_data(der_chain_in))
{
if (der_chain.empty()) {
throw std::invalid_argument("empty certificate chain");
}
// Parse the chain
auto parsed = std::vector<Certificate>();
for (const auto& cert : der_chain) {
parsed.emplace_back(cert.data);
}
// first element represents leaf cert
const auto& sig = find_signature(parsed[0].public_key_algorithm());
const auto pub_data = sig.serialize(*parsed[0].public_key);
_signature_scheme = tls_signature_scheme(parsed[0].public_key_algorithm());
_public_key = SignaturePublicKey{ pub_data };
// verify chain for valid signatures
for (size_t i = 0; i < der_chain.size() - 1; i++) {
if (!parsed[i].valid_from(parsed[i + 1])) {
throw std::runtime_error("Certificate Chain validation failure");
}
}
}
SignatureScheme
X509Credential::signature_scheme() const
{
return _signature_scheme;
}
SignaturePublicKey
X509Credential::public_key() const
{
return _public_key;
}
bool
X509Credential::valid_for(const SignaturePublicKey& pub) const
{
return pub == public_key();
}
tls::ostream&
operator<<(tls::ostream& str, const X509Credential& obj)
{
return str << obj.der_chain;
}
tls::istream&
operator>>(tls::istream& str, X509Credential& obj)
{
auto der_chain = std::vector<X509Credential::CertData>{};
str >> der_chain;
auto der_in = stdx::transform<bytes>(
der_chain, [](const auto& cert_data) { return cert_data.data; });
obj = X509Credential(der_in);
return str;
}
bool
operator==(const X509Credential& lhs, const X509Credential& rhs)
{
return lhs.der_chain == rhs.der_chain;
}
///
/// UserInfoVCCredential
///
UserInfoVCCredential::UserInfoVCCredential(std::string userinfo_vc_jwt_in)
: userinfo_vc_jwt(std::move(userinfo_vc_jwt_in))
, _vc(std::make_shared<hpke::UserInfoVC>(userinfo_vc_jwt))
{
}
bool
// NOLINTNEXTLINE(readability-convert-member-functions-to-static)
UserInfoVCCredential::valid_for(const SignaturePublicKey& pub) const
{
const auto& vc_pub = _vc->public_key();
return pub.data == vc_pub.sig.serialize(*vc_pub.key);
}
bool
UserInfoVCCredential::valid_from(const PublicJWK& pub) const
{
const auto& sig = _vc->signature_algorithm();
if (pub.signature_scheme != tls_signature_scheme(sig.id)) {
return false;
}
const auto sig_pub = sig.deserialize(pub.public_key.data);
return _vc->valid_from(*sig_pub);
}
tls::ostream
operator<<(tls::ostream& str, const UserInfoVCCredential& obj)
{
return str << from_ascii(obj.userinfo_vc_jwt);
}
tls::istream
operator>>(tls::istream& str, UserInfoVCCredential& obj)
{
auto jwt = bytes{};
str >> jwt;
obj = UserInfoVCCredential(to_ascii(jwt));
return str;
}
bool
operator==(const UserInfoVCCredential& lhs, const UserInfoVCCredential& rhs)
{
return lhs.userinfo_vc_jwt == rhs.userinfo_vc_jwt;
}
bool
operator!=(const UserInfoVCCredential& lhs, const UserInfoVCCredential& rhs)
{
return !(lhs == rhs);
}
///
/// CredentialBinding and MultiCredential
///
struct CredentialBindingTBS
{
const CipherSuite& cipher_suite;
const Credential& credential;
const SignaturePublicKey& credential_key;
const SignaturePublicKey& signature_key;
TLS_SERIALIZABLE(cipher_suite, credential, credential_key, signature_key)
};
CredentialBinding::CredentialBinding(CipherSuite cipher_suite_in,
Credential credential_in,
const SignaturePrivateKey& credential_priv,
const SignaturePublicKey& signature_key)
: cipher_suite(cipher_suite_in)
, credential(std::move(credential_in))
, credential_key(credential_priv.public_key)
{
if (credential.type() == CredentialType::multi_draft_00) {
throw InvalidParameterError("Multi-credentials cannot be nested");
}
if (!credential.valid_for(credential_key)) {
throw InvalidParameterError("Credential key does not match credential");
}
signature = credential_priv.sign(
cipher_suite, sign_label::multi_credential, to_be_signed(signature_key));
}
bytes
CredentialBinding::to_be_signed(const SignaturePublicKey& signature_key) const
{
return tls::marshal(CredentialBindingTBS{
cipher_suite, credential, credential_key, signature_key });
}
bool
CredentialBinding::valid_for(const SignaturePublicKey& signature_key) const
{
auto valid_self = credential.valid_for(credential_key);
auto valid_other = credential_key.verify(cipher_suite,
sign_label::multi_credential,
to_be_signed(signature_key),
signature);
return valid_self && valid_other;
}
MultiCredential::MultiCredential(
const std::vector<CredentialBindingInput>& binding_inputs,
const SignaturePublicKey& signature_key)
{
bindings =
stdx::transform<CredentialBinding>(binding_inputs, [&](auto&& input) {
return CredentialBinding(input.cipher_suite,
input.credential,
input.credential_priv,
signature_key);
});
}
bool
MultiCredential::valid_for(const SignaturePublicKey& pub) const
{
return stdx::all_of(
bindings, [&](const auto& binding) { return binding.valid_for(pub); });
}
///
/// Credential
///
CredentialType
Credential::type() const
{
return tls::variant<CredentialType>::type(_cred);
}
Credential
Credential::basic(const bytes& identity)
{
return { BasicCredential{ identity } };
}
Credential
Credential::x509(const std::vector<bytes>& der_chain)
{
return { X509Credential{ der_chain } };
}
Credential
Credential::multi(const std::vector<CredentialBindingInput>& binding_inputs,
const SignaturePublicKey& signature_key)
{
return { MultiCredential{ binding_inputs, signature_key } };
}
Credential
Credential::userinfo_vc(const std::string& userinfo_vc_jwt)
{
return { UserInfoVCCredential{ userinfo_vc_jwt } };
}
bool
Credential::valid_for(const SignaturePublicKey& pub) const
{
const auto pub_key_match = overloaded{
[&](const X509Credential& x509) { return x509.valid_for(pub); },
[](const BasicCredential& /* basic */) { return true; },
[&](const UserInfoVCCredential& vc) { return vc.valid_for(pub); },
[&](const MultiCredential& multi) { return multi.valid_for(pub); },
};
return var::visit(pub_key_match, _cred);
}
Credential::Credential(SpecificCredential specific)
: _cred(std::move(specific))
{
}
} // namespace mlspp

498
DPP/mlspp/src/crypto.cpp Executable file
View File

@@ -0,0 +1,498 @@
#include <mls/core_types.h>
#include <mls/crypto.h>
#include <mls/messages.h>
#include <string>
using mlspp::hpke::AEAD; // NOLINT(misc-unused-using-decls)
using mlspp::hpke::Digest; // NOLINT(misc-unused-using-decls)
using mlspp::hpke::HPKE; // NOLINT(misc-unused-using-decls)
using mlspp::hpke::KDF; // NOLINT(misc-unused-using-decls)
using mlspp::hpke::KEM; // NOLINT(misc-unused-using-decls)
using mlspp::hpke::Signature; // NOLINT(misc-unused-using-decls)
namespace mlspp {
SignatureScheme
tls_signature_scheme(Signature::ID id)
{
switch (id) {
case Signature::ID::P256_SHA256:
return SignatureScheme::ecdsa_secp256r1_sha256;
case Signature::ID::P384_SHA384:
return SignatureScheme::ecdsa_secp384r1_sha384;
case Signature::ID::P521_SHA512:
return SignatureScheme::ecdsa_secp521r1_sha512;
case Signature::ID::Ed25519:
return SignatureScheme::ed25519;
#if !defined(WITH_BORINGSSL)
case Signature::ID::Ed448:
return SignatureScheme::ed448;
#endif
case Signature::ID::RSA_SHA256:
return SignatureScheme::rsa_pkcs1_sha256;
default:
throw InvalidParameterError("Unsupported algorithm");
}
}
///
/// CipherSuites and details
///
CipherSuite::CipherSuite()
: id(ID::unknown)
{
}
CipherSuite::CipherSuite(ID id_in)
: id(id_in)
{
}
SignatureScheme
CipherSuite::signature_scheme() const
{
switch (id) {
case ID::X25519_AES128GCM_SHA256_Ed25519:
case ID::X25519_CHACHA20POLY1305_SHA256_Ed25519:
return SignatureScheme::ed25519;
case ID::P256_AES128GCM_SHA256_P256:
return SignatureScheme::ecdsa_secp256r1_sha256;
case ID::X448_AES256GCM_SHA512_Ed448:
case ID::X448_CHACHA20POLY1305_SHA512_Ed448:
return SignatureScheme::ed448;
case ID::P521_AES256GCM_SHA512_P521:
return SignatureScheme::ecdsa_secp521r1_sha512;
case ID::P384_AES256GCM_SHA384_P384:
return SignatureScheme::ecdsa_secp384r1_sha384;
default:
throw InvalidParameterError("Unsupported algorithm");
}
}
const CipherSuite::Ciphers&
CipherSuite::get() const
{
static const auto ciphers_X25519_AES128GCM_SHA256_Ed25519 =
CipherSuite::Ciphers{
HPKE(KEM::ID::DHKEM_X25519_SHA256,
KDF::ID::HKDF_SHA256,
AEAD::ID::AES_128_GCM),
Digest::get<Digest::ID::SHA256>(),
Signature::get<Signature::ID::Ed25519>(),
};
static const auto ciphers_P256_AES128GCM_SHA256_P256 = CipherSuite::Ciphers{
HPKE(
KEM::ID::DHKEM_P256_SHA256, KDF::ID::HKDF_SHA256, AEAD::ID::AES_128_GCM),
Digest::get<Digest::ID::SHA256>(),
Signature::get<Signature::ID::P256_SHA256>(),
};
static const auto ciphers_X25519_CHACHA20POLY1305_SHA256_Ed25519 =
CipherSuite::Ciphers{
HPKE(KEM::ID::DHKEM_X25519_SHA256,
KDF::ID::HKDF_SHA256,
AEAD::ID::CHACHA20_POLY1305),
Digest::get<Digest::ID::SHA256>(),
Signature::get<Signature::ID::Ed25519>(),
};
static const auto ciphers_P521_AES256GCM_SHA512_P521 = CipherSuite::Ciphers{
HPKE(
KEM::ID::DHKEM_P521_SHA512, KDF::ID::HKDF_SHA512, AEAD::ID::AES_256_GCM),
Digest::get<Digest::ID::SHA512>(),
Signature::get<Signature::ID::P521_SHA512>(),
};
static const auto ciphers_P384_AES256GCM_SHA384_P384 = CipherSuite::Ciphers{
HPKE(
KEM::ID::DHKEM_P384_SHA384, KDF::ID::HKDF_SHA384, AEAD::ID::AES_256_GCM),
Digest::get<Digest::ID::SHA384>(),
Signature::get<Signature::ID::P384_SHA384>(),
};
#if !defined(WITH_BORINGSSL)
static const auto ciphers_X448_AES256GCM_SHA512_Ed448 = CipherSuite::Ciphers{
HPKE(
KEM::ID::DHKEM_X448_SHA512, KDF::ID::HKDF_SHA512, AEAD::ID::AES_256_GCM),
Digest::get<Digest::ID::SHA512>(),
Signature::get<Signature::ID::Ed448>(),
};
static const auto ciphers_X448_CHACHA20POLY1305_SHA512_Ed448 =
CipherSuite::Ciphers{
HPKE(KEM::ID::DHKEM_X448_SHA512,
KDF::ID::HKDF_SHA512,
AEAD::ID::CHACHA20_POLY1305),
Digest::get<Digest::ID::SHA512>(),
Signature::get<Signature::ID::Ed448>(),
};
#endif
switch (id) {
case ID::unknown:
throw InvalidParameterError("Uninitialized ciphersuite");
case ID::X25519_AES128GCM_SHA256_Ed25519:
return ciphers_X25519_AES128GCM_SHA256_Ed25519;
case ID::P256_AES128GCM_SHA256_P256:
return ciphers_P256_AES128GCM_SHA256_P256;
case ID::X25519_CHACHA20POLY1305_SHA256_Ed25519:
return ciphers_X25519_CHACHA20POLY1305_SHA256_Ed25519;
case ID::P521_AES256GCM_SHA512_P521:
return ciphers_P521_AES256GCM_SHA512_P521;
case ID::P384_AES256GCM_SHA384_P384:
return ciphers_P384_AES256GCM_SHA384_P384;
#if !defined(WITH_BORINGSSL)
case ID::X448_AES256GCM_SHA512_Ed448:
return ciphers_X448_AES256GCM_SHA512_Ed448;
case ID::X448_CHACHA20POLY1305_SHA512_Ed448:
return ciphers_X448_CHACHA20POLY1305_SHA512_Ed448;
#endif
default:
throw InvalidParameterError("Unsupported ciphersuite");
}
}
struct HKDFLabel
{
uint16_t length;
bytes label;
bytes context;
TLS_SERIALIZABLE(length, label, context)
};
bytes
CipherSuite::expand_with_label(const bytes& secret,
const std::string& label,
const bytes& context,
size_t length) const
{
auto mls_label = from_ascii(std::string("MLS 1.0 ") + label);
auto length16 = static_cast<uint16_t>(length);
auto label_bytes = tls::marshal(HKDFLabel{ length16, mls_label, context });
return get().hpke.kdf.expand(secret, label_bytes, length);
}
bytes
CipherSuite::derive_secret(const bytes& secret, const std::string& label) const
{
return expand_with_label(secret, label, {}, secret_size());
}
bytes
CipherSuite::derive_tree_secret(const bytes& secret,
const std::string& label,
uint32_t generation,
size_t length) const
{
return expand_with_label(secret, label, tls::marshal(generation), length);
}
#if WITH_BORINGSSL
const std::array<CipherSuite::ID, 5> all_supported_suites = {
CipherSuite::ID::X25519_AES128GCM_SHA256_Ed25519,
CipherSuite::ID::P256_AES128GCM_SHA256_P256,
CipherSuite::ID::X25519_CHACHA20POLY1305_SHA256_Ed25519,
CipherSuite::ID::P521_AES256GCM_SHA512_P521,
CipherSuite::ID::P384_AES256GCM_SHA384_P384,
};
#else
const std::array<CipherSuite::ID, 7> all_supported_suites = {
CipherSuite::ID::X25519_AES128GCM_SHA256_Ed25519,
CipherSuite::ID::P256_AES128GCM_SHA256_P256,
CipherSuite::ID::X25519_CHACHA20POLY1305_SHA256_Ed25519,
CipherSuite::ID::P521_AES256GCM_SHA512_P521,
CipherSuite::ID::P384_AES256GCM_SHA384_P384,
CipherSuite::ID::X448_CHACHA20POLY1305_SHA512_Ed448,
CipherSuite::ID::X448_AES256GCM_SHA512_Ed448,
};
#endif
// MakeKeyPackageRef(value) = KDF.expand(
// KDF.extract("", value), "MLS 1.0 KeyPackage Reference", 16)
template<>
const bytes&
CipherSuite::reference_label<KeyPackage>()
{
static const auto label = from_ascii("MLS 1.0 KeyPackage Reference");
return label;
}
// MakeProposalRef(value) = KDF.expand(
// KDF.extract("", value), "MLS 1.0 Proposal Reference", 16)
//
// Even though the label says "Proposal", we actually hash the entire enclosing
// AuthenticatedContent object.
template<>
const bytes&
CipherSuite::reference_label<AuthenticatedContent>()
{
static const auto label = from_ascii("MLS 1.0 Proposal Reference");
return label;
}
///
/// HPKEPublicKey and HPKEPrivateKey
///
// This function produces a non-literal type, so it can't be constexpr.
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define MLS_1_0_PLUS(label) from_ascii("MLS 1.0 " label)
static bytes
mls_1_0_plus(const std::string& label)
{
auto plus = "MLS 1.0 "s + label;
return from_ascii(plus);
}
namespace encrypt_label {
const std::string update_path_node = "UpdatePathNode";
const std::string welcome = "Welcome";
} // namespace encrypt_label
struct EncryptContext
{
const bytes& label;
const bytes& content;
TLS_SERIALIZABLE(label, content)
};
HPKECiphertext
HPKEPublicKey::encrypt(CipherSuite suite,
const std::string& label,
const bytes& context,
const bytes& pt) const
{
auto label_plus = mls_1_0_plus(label);
auto encrypt_context = tls::marshal(EncryptContext{ label_plus, context });
auto pkR = suite.hpke().kem.deserialize(data);
auto [enc, ctx] = suite.hpke().setup_base_s(*pkR, encrypt_context);
auto ct = ctx.seal({}, pt);
return HPKECiphertext{ enc, ct };
}
std::tuple<bytes, bytes>
HPKEPublicKey::do_export(CipherSuite suite,
const bytes& info,
const std::string& label,
size_t size) const
{
auto label_data = from_ascii(label);
auto pkR = suite.hpke().kem.deserialize(data);
auto [enc, ctx] = suite.hpke().setup_base_s(*pkR, info);
auto exported = ctx.do_export(label_data, size);
return std::make_tuple(enc, exported);
}
HPKEPrivateKey
HPKEPrivateKey::generate(CipherSuite suite)
{
auto priv = suite.hpke().kem.generate_key_pair();
auto priv_data = suite.hpke().kem.serialize_private(*priv);
auto pub = priv->public_key();
auto pub_data = suite.hpke().kem.serialize(*pub);
return { priv_data, pub_data };
}
HPKEPrivateKey
HPKEPrivateKey::parse(CipherSuite suite, const bytes& data)
{
auto priv = suite.hpke().kem.deserialize_private(data);
auto pub = priv->public_key();
auto pub_data = suite.hpke().kem.serialize(*pub);
return { data, pub_data };
}
HPKEPrivateKey
HPKEPrivateKey::derive(CipherSuite suite, const bytes& secret)
{
auto priv = suite.hpke().kem.derive_key_pair(secret);
auto priv_data = suite.hpke().kem.serialize_private(*priv);
auto pub = priv->public_key();
auto pub_data = suite.hpke().kem.serialize(*pub);
return { priv_data, pub_data };
}
bytes
HPKEPrivateKey::decrypt(CipherSuite suite,
const std::string& label,
const bytes& context,
const HPKECiphertext& ct) const
{
auto label_plus = mls_1_0_plus(label);
auto encrypt_context = tls::marshal(EncryptContext{ label_plus, context });
auto skR = suite.hpke().kem.deserialize_private(data);
auto ctx = suite.hpke().setup_base_r(ct.kem_output, *skR, encrypt_context);
auto pt = ctx.open({}, ct.ciphertext);
if (!pt) {
throw InvalidParameterError("HPKE decryption failure");
}
return opt::get(pt);
}
bytes
HPKEPrivateKey::do_export(CipherSuite suite,
const bytes& info,
const bytes& kem_output,
const std::string& label,
size_t size) const
{
auto label_data = from_ascii(label);
auto skR = suite.hpke().kem.deserialize_private(data);
auto ctx = suite.hpke().setup_base_r(kem_output, *skR, info);
return ctx.do_export(label_data, size);
}
HPKEPrivateKey::HPKEPrivateKey(bytes priv_data, bytes pub_data)
: data(std::move(priv_data))
, public_key{ std::move(pub_data) }
{
}
void
HPKEPrivateKey::set_public_key(CipherSuite suite)
{
const auto priv = suite.hpke().kem.deserialize_private(data);
auto pub = priv->public_key();
public_key.data = suite.hpke().kem.serialize(*pub);
}
///
/// SignaturePublicKey and SignaturePrivateKey
///
namespace sign_label {
const std::string mls_content = "FramedContentTBS";
const std::string leaf_node = "LeafNodeTBS";
const std::string key_package = "KeyPackageTBS";
const std::string group_info = "GroupInfoTBS";
const std::string multi_credential = "MultiCredential";
} // namespace sign_label
struct SignContent
{
const bytes& label;
const bytes& content;
TLS_SERIALIZABLE(label, content)
};
bool
SignaturePublicKey::verify(const CipherSuite& suite,
const std::string& label,
const bytes& message,
const bytes& signature) const
{
auto label_plus = mls_1_0_plus(label);
const auto content = tls::marshal(SignContent{ label_plus, message });
auto pub = suite.sig().deserialize(data);
return suite.sig().verify(content, signature, *pub);
}
SignaturePublicKey
SignaturePublicKey::from_jwk(CipherSuite suite, const std::string& json_str)
{
auto pub = suite.sig().import_jwk(json_str);
auto pub_data = suite.sig().serialize(*pub);
return SignaturePublicKey{ pub_data };
}
std::string
SignaturePublicKey::to_jwk(CipherSuite suite) const
{
auto pub = suite.sig().deserialize(data);
return suite.sig().export_jwk(*pub);
}
PublicJWK
PublicJWK::parse(const std::string& jwk_json)
{
const auto parsed = Signature::parse_jwk(jwk_json);
const auto scheme = tls_signature_scheme(parsed.sig.id);
const auto pub_data = parsed.sig.serialize(*parsed.key);
return { scheme, parsed.key_id, { pub_data } };
}
SignaturePrivateKey
SignaturePrivateKey::generate(CipherSuite suite)
{
auto priv = suite.sig().generate_key_pair();
auto priv_data = suite.sig().serialize_private(*priv);
auto pub = priv->public_key();
auto pub_data = suite.sig().serialize(*pub);
return { priv_data, pub_data };
}
SignaturePrivateKey
SignaturePrivateKey::parse(CipherSuite suite, const bytes& data)
{
auto priv = suite.sig().deserialize_private(data);
auto pub = priv->public_key();
auto pub_data = suite.sig().serialize(*pub);
return { data, pub_data };
}
SignaturePrivateKey
SignaturePrivateKey::derive(CipherSuite suite, const bytes& secret)
{
auto priv = suite.sig().derive_key_pair(secret);
auto priv_data = suite.sig().serialize_private(*priv);
auto pub = priv->public_key();
auto pub_data = suite.sig().serialize(*pub);
return { priv_data, pub_data };
}
bytes
SignaturePrivateKey::sign(const CipherSuite& suite,
const std::string& label,
const bytes& message) const
{
auto label_plus = mls_1_0_plus(label);
const auto content = tls::marshal(SignContent{ label_plus, message });
const auto priv = suite.sig().deserialize_private(data);
return suite.sig().sign(content, *priv);
}
SignaturePrivateKey::SignaturePrivateKey(bytes priv_data, bytes pub_data)
: data(std::move(priv_data))
, public_key{ std::move(pub_data) }
{
}
void
SignaturePrivateKey::set_public_key(CipherSuite suite)
{
const auto priv = suite.sig().deserialize_private(data);
auto pub = priv->public_key();
public_key.data = suite.sig().serialize(*pub);
}
SignaturePrivateKey
SignaturePrivateKey::from_jwk(CipherSuite suite, const std::string& json_str)
{
auto priv = suite.sig().import_jwk_private(json_str);
auto priv_data = suite.sig().serialize_private(*priv);
auto pub = priv->public_key();
auto pub_data = suite.sig().serialize(*pub);
return { priv_data, pub_data };
}
std::string
SignaturePrivateKey::to_jwk(CipherSuite suite) const
{
const auto priv = suite.sig().deserialize_private(data);
return suite.sig().export_jwk_private(*priv);
}
} // namespace mlspp

126
DPP/mlspp/src/grease.cpp Executable file
View File

@@ -0,0 +1,126 @@
#include "grease.h"
#include <random>
#include <set>
namespace mlspp {
#ifdef DISABLE_GREASE
void
grease([[maybe_unused]] Capabilities& capabilities,
[[maybe_unused]] const ExtensionList& extensions)
{
}
void
grease([[maybe_unused]] ExtensionList& extensions)
{
}
#else
// Randomness parmeters:
// * Given a list of N items, insert max(1, rand(p_grease * N)) GREASE values
// * Each GREASE value added is distinct, unless more than 15 values are needed
// * For extensions, each GREASE extension has rand(n_grease_ext) random bytes
// of data
const size_t log_p_grease = 1; // -log2(p_grease) => p_grease = 1/2
const size_t max_grease_ext_size = 16;
const std::array<uint16_t, 15> grease_values = { 0x0A0A, 0x1A1A, 0x2A2A, 0x3A3A,
0x4A4A, 0x5A5A, 0x6A6A, 0x7A7A,
0x8A8A, 0x9A9A, 0xAAAA, 0xBABA,
0xCACA, 0xDADA, 0xEAEA };
static size_t
rand_int(size_t n)
{
static auto seed = std::random_device()();
static auto rng = std::mt19937(seed);
return std::uniform_int_distribution<size_t>(0, n)(rng);
}
static uint16_t
grease_value()
{
const auto where = rand_int(grease_values.size() - 1);
return grease_values.at(where);
}
static bool
grease_value(uint16_t val)
{
static constexpr auto grease_mask = uint16_t(0x0F0F);
return ((val & grease_mask) == 0x0A0A) && val != 0xFAFA;
}
static std::set<uint16_t>
grease_sample(size_t count)
{
auto vals = std::set<uint16_t>{};
while (vals.size() < count) {
uint16_t val = grease_value();
while (vals.count(val) > 0 && vals.size() < grease_values.size()) {
val = grease_value();
}
vals.insert(val);
}
return vals;
}
template<typename T>
static void
grease(std::vector<T>& vec)
{
const auto count = std::max(size_t(1), rand_int(vec.size() >> log_p_grease));
for (const auto val : grease_sample(count)) {
const auto where = static_cast<ptrdiff_t>(rand_int(vec.size()));
vec.insert(std::begin(vec) + where, static_cast<T>(val));
}
}
void
grease(Capabilities& capabilities, const ExtensionList& extensions)
{
// Add GREASE to the appropriate portions of the capabilities
grease(capabilities.cipher_suites);
grease(capabilities.extensions);
grease(capabilities.proposals);
grease(capabilities.credentials);
// Ensure that the GREASE extensions are reflected in Capabilities.extensions
for (const auto& ext : extensions.extensions) {
if (!grease_value(ext.type)) {
continue;
}
if (stdx::contains(capabilities.extensions, ext.type)) {
continue;
}
const auto where =
static_cast<ptrdiff_t>(rand_int(capabilities.extensions.size()));
const auto where_ptr = std::begin(capabilities.extensions) + where;
capabilities.extensions.insert(where_ptr, ext.type);
}
}
void
grease(ExtensionList& extensions)
{
auto& ext = extensions.extensions;
const auto count = std::max(size_t(1), rand_int(ext.size() >> log_p_grease));
for (const auto ext_type : grease_sample(count)) {
const auto where = static_cast<ptrdiff_t>(rand_int(ext.size()));
auto ext_data = random_bytes(rand_int(max_grease_ext_size));
ext.insert(std::begin(ext) + where, { ext_type, std::move(ext_data) });
}
}
#endif // DISABLE_GREASE
} // namespace mlspp

13
DPP/mlspp/src/grease.h Executable file
View File

@@ -0,0 +1,13 @@
#pragma once
#include "mls/core_types.h"
namespace mlspp {
void
grease(Capabilities& capabilities, const ExtensionList& extensions);
void
grease(ExtensionList& extensions);
} // namespace mlspp

579
DPP/mlspp/src/key_schedule.cpp Executable file
View File

@@ -0,0 +1,579 @@
#include <mls/key_schedule.h>
namespace mlspp {
///
/// Key Derivation Functions
///
struct TreeContext
{
NodeIndex node;
uint32_t generation = 0;
TLS_SERIALIZABLE(node, generation)
};
///
/// HashRatchet
///
HashRatchet::HashRatchet(CipherSuite suite_in, bytes base_secret_in)
: suite(suite_in)
, next_secret(std::move(base_secret_in))
, next_generation(0)
, key_size(suite.hpke().aead.key_size)
, nonce_size(suite.hpke().aead.nonce_size)
, secret_size(suite.secret_size())
{
}
std::tuple<uint32_t, KeyAndNonce>
HashRatchet::next()
{
auto generation = next_generation;
auto key = suite.derive_tree_secret(next_secret, "key", generation, key_size);
auto nonce =
suite.derive_tree_secret(next_secret, "nonce", generation, nonce_size);
auto secret =
suite.derive_tree_secret(next_secret, "secret", generation, secret_size);
next_generation += 1;
next_secret = secret;
cache[generation] = { key, nonce };
return { generation, cache.at(generation) };
}
// Note: This construction deliberately does not preserve the forward-secrecy
// invariant, in that keys/nonces are not deleted after they are used.
// Otherwise, it would not be possible for a node to send to itself. Keys can
// be deleted once they are not needed by calling HashRatchet::erase().
KeyAndNonce
HashRatchet::get(uint32_t generation)
{
if (cache.count(generation) > 0) {
auto out = cache.at(generation);
return out;
}
if (next_generation > generation) {
throw ProtocolError("Request for expired key");
}
while (next_generation <= generation) {
next();
}
return cache.at(generation);
}
void
HashRatchet::erase(uint32_t generation)
{
if (cache.count(generation) == 0) {
return;
}
cache.erase(generation);
}
///
/// SecretTree
///
SecretTree::SecretTree(CipherSuite suite_in,
LeafCount group_size_in,
bytes encryption_secret_in)
: suite(suite_in)
, group_size(LeafCount::full(group_size_in))
, root(NodeIndex::root(group_size))
, secret_size(suite_in.secret_size())
{
secrets.emplace(root, std::move(encryption_secret_in));
}
bytes
SecretTree::get(LeafIndex sender)
{
static const auto context_left = from_ascii("left");
static const auto context_right = from_ascii("right");
auto node = NodeIndex(sender);
// Find an ancestor that is populated
auto dirpath = node.dirpath(group_size);
dirpath.insert(dirpath.begin(), node);
dirpath.push_back(root);
uint32_t curr = 0;
for (; curr < dirpath.size(); ++curr) {
auto i = dirpath.at(curr);
if (secrets.count(i) > 0) {
break;
}
}
if (curr > dirpath.size()) {
throw InvalidParameterError("No secret found to derive base key");
}
// Derive down
for (; curr > 0; --curr) {
auto curr_node = dirpath.at(curr);
auto left = curr_node.left();
auto right = curr_node.right();
auto& secret = secrets.at(curr_node);
const auto left_secret =
suite.expand_with_label(secret, "tree", context_left, secret_size);
const auto right_secret =
suite.expand_with_label(secret, "tree", context_right, secret_size);
secrets.insert_or_assign(left, left_secret);
secrets.insert_or_assign(right, right_secret);
}
// Copy the leaf
auto out = secrets.at(node);
// Zeroize along the direct path
for (auto i : dirpath) {
secrets.erase(i);
}
return out;
}
///
/// ReuseGuard
///
static ReuseGuard
new_reuse_guard()
{
auto random = random_bytes(4);
auto guard = ReuseGuard();
std::copy(random.begin(), random.end(), guard.begin());
return guard;
}
static void
apply_reuse_guard(const ReuseGuard& guard, bytes& nonce)
{
for (size_t i = 0; i < guard.size(); i++) {
nonce.at(i) ^= guard.at(i);
}
}
///
/// GroupKeySource
///
GroupKeySource::GroupKeySource(CipherSuite suite_in,
LeafCount group_size,
bytes encryption_secret)
: suite(suite_in)
, secret_tree(suite, group_size, std::move(encryption_secret))
{
}
HashRatchet&
GroupKeySource::chain(ContentType type, LeafIndex sender)
{
switch (type) {
case ContentType::proposal:
case ContentType::commit:
return chain(RatchetType::handshake, sender);
case ContentType::application:
return chain(RatchetType::application, sender);
default:
throw InvalidParameterError("Invalid content type");
}
}
HashRatchet&
GroupKeySource::chain(RatchetType type, LeafIndex sender)
{
auto key = Key{ type, sender };
if (chains.count(key) > 0) {
return chains[key];
}
auto secret_size = suite.secret_size();
auto leaf_secret = secret_tree.get(sender);
auto handshake_secret =
suite.expand_with_label(leaf_secret, "handshake", {}, secret_size);
auto application_secret =
suite.expand_with_label(leaf_secret, "application", {}, secret_size);
chains.emplace(Key{ RatchetType::handshake, sender },
HashRatchet{ suite, handshake_secret });
chains.emplace(Key{ RatchetType::application, sender },
HashRatchet{ suite, application_secret });
return chains[key];
}
std::tuple<uint32_t, ReuseGuard, KeyAndNonce>
GroupKeySource::next(ContentType type, LeafIndex sender)
{
auto [generation, keys] = chain(type, sender).next();
auto reuse_guard = new_reuse_guard();
apply_reuse_guard(reuse_guard, keys.nonce);
return { generation, reuse_guard, keys };
}
KeyAndNonce
GroupKeySource::get(ContentType type,
LeafIndex sender,
uint32_t generation,
ReuseGuard reuse_guard)
{
auto keys = chain(type, sender).get(generation);
apply_reuse_guard(reuse_guard, keys.nonce);
return keys;
}
void
GroupKeySource::erase(ContentType type, LeafIndex sender, uint32_t generation)
{
return chain(type, sender).erase(generation);
}
// struct {
// opaque group_id<0..255>;
// uint64 epoch;
// ContentType content_type;
// opaque authenticated_data<0..2^32-1>;
// } ContentAAD;
struct ContentAAD
{
const bytes& group_id;
const epoch_t epoch;
const ContentType content_type;
const bytes& authenticated_data;
TLS_SERIALIZABLE(group_id, epoch, content_type, authenticated_data)
};
///
/// KeyScheduleEpoch
///
struct PSKLabel
{
const PreSharedKeyID& id;
uint16_t index;
uint16_t count;
TLS_SERIALIZABLE(id, index, count);
};
static bytes
make_joiner_secret(CipherSuite suite,
const bytes& context,
const bytes& init_secret,
const bytes& commit_secret)
{
auto pre_joiner_secret = suite.hpke().kdf.extract(init_secret, commit_secret);
return suite.expand_with_label(
pre_joiner_secret, "joiner", context, suite.secret_size());
}
static bytes
make_epoch_secret(CipherSuite suite,
const bytes& joiner_secret,
const bytes& psk_secret,
const bytes& context)
{
auto member_secret = suite.hpke().kdf.extract(joiner_secret, psk_secret);
return suite.expand_with_label(
member_secret, "epoch", context, suite.secret_size());
}
KeyScheduleEpoch
KeyScheduleEpoch::joiner(CipherSuite suite_in,
const bytes& joiner_secret,
const std::vector<PSKWithSecret>& psks,
const bytes& context)
{
return { suite_in, joiner_secret, make_psk_secret(suite_in, psks), context };
}
KeyScheduleEpoch::KeyScheduleEpoch(CipherSuite suite_in,
const bytes& joiner_secret,
const bytes& psk_secret,
const bytes& context)
: suite(suite_in)
, joiner_secret(joiner_secret)
, epoch_secret(
make_epoch_secret(suite_in, joiner_secret, psk_secret, context))
, sender_data_secret(suite.derive_secret(epoch_secret, "sender data"))
, encryption_secret(suite.derive_secret(epoch_secret, "encryption"))
, exporter_secret(suite.derive_secret(epoch_secret, "exporter"))
, epoch_authenticator(suite.derive_secret(epoch_secret, "authentication"))
, external_secret(suite.derive_secret(epoch_secret, "external"))
, confirmation_key(suite.derive_secret(epoch_secret, "confirm"))
, membership_key(suite.derive_secret(epoch_secret, "membership"))
, resumption_psk(suite.derive_secret(epoch_secret, "resumption"))
, init_secret(suite.derive_secret(epoch_secret, "init"))
, external_priv(HPKEPrivateKey::derive(suite, external_secret))
{
}
KeyScheduleEpoch::KeyScheduleEpoch(CipherSuite suite_in)
: suite(suite_in)
{
}
KeyScheduleEpoch::KeyScheduleEpoch(CipherSuite suite_in,
const bytes& init_secret,
const bytes& context)
: KeyScheduleEpoch(
suite_in,
make_joiner_secret(suite_in, context, init_secret, suite_in.zero()),
{ /* no PSKs */ },
context)
{
}
KeyScheduleEpoch::KeyScheduleEpoch(CipherSuite suite_in,
const bytes& init_secret,
const bytes& commit_secret,
const bytes& psk_secret,
const bytes& context)
: KeyScheduleEpoch(
suite_in,
make_joiner_secret(suite_in, context, init_secret, commit_secret),
psk_secret,
context)
{
}
std::tuple<bytes, bytes>
KeyScheduleEpoch::external_init(CipherSuite suite,
const HPKEPublicKey& external_pub)
{
auto size = suite.secret_size();
return external_pub.do_export(
suite, {}, "MLS 1.0 external init secret", size);
}
bytes
KeyScheduleEpoch::receive_external_init(const bytes& kem_output) const
{
auto size = suite.secret_size();
return external_priv.do_export(
suite, {}, kem_output, "MLS 1.0 external init secret", size);
}
KeyScheduleEpoch
KeyScheduleEpoch::next(const bytes& commit_secret,
const std::vector<PSKWithSecret>& psks,
const std::optional<bytes>& force_init_secret,
const bytes& context) const
{
return next_raw(
commit_secret, make_psk_secret(suite, psks), force_init_secret, context);
}
KeyScheduleEpoch
KeyScheduleEpoch::next_raw(const bytes& commit_secret,
const bytes& psk_secret,
const std::optional<bytes>& force_init_secret,
const bytes& context) const
{
auto actual_init_secret = init_secret;
if (force_init_secret) {
actual_init_secret = opt::get(force_init_secret);
}
return { suite, actual_init_secret, commit_secret, psk_secret, context };
}
GroupKeySource
KeyScheduleEpoch::encryption_keys(LeafCount size) const
{
return { suite, size, encryption_secret };
}
bytes
KeyScheduleEpoch::confirmation_tag(const bytes& confirmed_transcript_hash) const
{
return suite.digest().hmac(confirmation_key, confirmed_transcript_hash);
}
bytes
KeyScheduleEpoch::do_export(const std::string& label,
const bytes& context,
size_t size) const
{
auto secret = suite.derive_secret(exporter_secret, label);
auto context_hash = suite.digest().hash(context);
return suite.expand_with_label(secret, "exported", context_hash, size);
}
PSKWithSecret
KeyScheduleEpoch::resumption_psk_w_secret(ResumptionPSKUsage usage,
const bytes& group_id,
epoch_t epoch)
{
auto nonce = random_bytes(suite.secret_size());
auto psk = ResumptionPSK{ usage, group_id, epoch };
return { { psk, nonce }, resumption_psk };
}
bytes
KeyScheduleEpoch::make_psk_secret(CipherSuite suite,
const std::vector<PSKWithSecret>& psks)
{
auto psk_secret = suite.zero();
auto count = uint16_t(psks.size());
auto index = uint16_t(0);
for (const auto& psk : psks) {
auto psk_extracted = suite.hpke().kdf.extract(suite.zero(), psk.secret);
auto psk_label = tls::marshal(PSKLabel{ psk.id, index, count });
auto psk_input = suite.expand_with_label(
psk_extracted, "derived psk", psk_label, suite.secret_size());
psk_secret = suite.hpke().kdf.extract(psk_input, psk_secret);
index += 1;
}
return psk_secret;
}
bytes
KeyScheduleEpoch::welcome_secret(CipherSuite suite,
const bytes& joiner_secret,
const std::vector<PSKWithSecret>& psks)
{
auto psk_secret = make_psk_secret(suite, psks);
return welcome_secret_raw(suite, joiner_secret, psk_secret);
}
bytes
KeyScheduleEpoch::welcome_secret_raw(CipherSuite suite,
const bytes& joiner_secret,
const bytes& psk_secret)
{
auto extract = suite.hpke().kdf.extract(joiner_secret, psk_secret);
return suite.derive_secret(extract, "welcome");
}
KeyAndNonce
KeyScheduleEpoch::sender_data_keys(CipherSuite suite,
const bytes& sender_data_secret,
const bytes& ciphertext)
{
auto sample_size = suite.secret_size();
auto sample = bytes(sample_size);
if (ciphertext.size() <= sample_size) {
sample = ciphertext;
} else {
sample = ciphertext.slice(0, sample_size);
}
auto key_size = suite.hpke().aead.key_size;
auto nonce_size = suite.hpke().aead.nonce_size;
return {
suite.expand_with_label(sender_data_secret, "key", sample, key_size),
suite.expand_with_label(sender_data_secret, "nonce", sample, nonce_size),
};
}
bool
operator==(const KeyScheduleEpoch& lhs, const KeyScheduleEpoch& rhs)
{
auto epoch_secret = (lhs.epoch_secret == rhs.epoch_secret);
auto sender_data_secret = (lhs.sender_data_secret == rhs.sender_data_secret);
auto encryption_secret = (lhs.encryption_secret == rhs.encryption_secret);
auto exporter_secret = (lhs.exporter_secret == rhs.exporter_secret);
auto confirmation_key = (lhs.confirmation_key == rhs.confirmation_key);
auto init_secret = (lhs.init_secret == rhs.init_secret);
auto external_priv = (lhs.external_priv == rhs.external_priv);
return epoch_secret && sender_data_secret && encryption_secret &&
exporter_secret && confirmation_key && init_secret && external_priv;
}
// struct {
// WireFormat wire_format;
// GroupContent content; // with content.content_type == commit
// opaque signature<V>;
// } ConfirmedTranscriptHashInput;
struct ConfirmedTranscriptHashInput
{
WireFormat wire_format;
const GroupContent& content;
const bytes& signature;
TLS_SERIALIZABLE(wire_format, content, signature)
};
// struct {
// MAC confirmation_tag;
// } InterimTranscriptHashInput;
struct InterimTranscriptHashInput
{
bytes confirmation_tag;
TLS_SERIALIZABLE(confirmation_tag)
};
TranscriptHash::TranscriptHash(CipherSuite suite_in)
: suite(suite_in)
{
}
TranscriptHash::TranscriptHash(CipherSuite suite_in,
bytes confirmed_in,
const bytes& confirmation_tag)
: suite(suite_in)
, confirmed(std::move(confirmed_in))
{
update_interim(confirmation_tag);
}
void
TranscriptHash::update(const AuthenticatedContent& content_auth)
{
update_confirmed(content_auth);
update_interim(content_auth);
}
void
TranscriptHash::update_confirmed(const AuthenticatedContent& content_auth)
{
const auto transcript =
interim + content_auth.confirmed_transcript_hash_input();
confirmed = suite.digest().hash(transcript);
}
void
TranscriptHash::update_interim(const bytes& confirmation_tag)
{
const auto transcript = confirmed + tls::marshal(confirmation_tag);
interim = suite.digest().hash(transcript);
}
void
TranscriptHash::update_interim(const AuthenticatedContent& content_auth)
{
const auto transcript =
confirmed + content_auth.interim_transcript_hash_input();
interim = suite.digest().hash(transcript);
}
bool
operator==(const TranscriptHash& lhs, const TranscriptHash& rhs)
{
auto confirmed = (lhs.confirmed == rhs.confirmed);
auto interim = (lhs.interim == rhs.interim);
return confirmed && interim;
}
} // namespace mlspp

947
DPP/mlspp/src/messages.cpp Executable file
View File

@@ -0,0 +1,947 @@
#include <mls/key_schedule.h>
#include <mls/messages.h>
#include <mls/state.h>
#include <mls/treekem.h>
#include "grease.h"
namespace mlspp {
// Extensions
const Extension::Type ExternalPubExtension::type = ExtensionType::external_pub;
const Extension::Type RatchetTreeExtension::type = ExtensionType::ratchet_tree;
const Extension::Type ExternalSendersExtension::type =
ExtensionType::external_senders;
const Extension::Type SFrameParameters::type = ExtensionType::sframe_parameters;
const Extension::Type SFrameCapabilities::type =
ExtensionType::sframe_parameters;
bool
SFrameCapabilities::compatible(const SFrameParameters& params) const
{
return stdx::contains(cipher_suites, params.cipher_suite);
}
// GroupContext
GroupContext::GroupContext(CipherSuite cipher_suite_in,
bytes group_id_in,
epoch_t epoch_in,
bytes tree_hash_in,
bytes confirmed_transcript_hash_in,
ExtensionList extensions_in)
: cipher_suite(cipher_suite_in)
, group_id(std::move(group_id_in))
, epoch(epoch_in)
, tree_hash(std::move(tree_hash_in))
, confirmed_transcript_hash(std::move(confirmed_transcript_hash_in))
, extensions(std::move(extensions_in))
{
}
// GroupInfo
GroupInfo::GroupInfo(GroupContext group_context_in,
ExtensionList extensions_in,
bytes confirmation_tag_in)
: group_context(std::move(group_context_in))
, extensions(std::move(extensions_in))
, confirmation_tag(std::move(confirmation_tag_in))
, signer(0)
{
grease(extensions);
}
struct GroupInfoTBS
{
GroupContext group_context;
ExtensionList extensions;
bytes confirmation_tag;
LeafIndex signer;
TLS_SERIALIZABLE(group_context, extensions, confirmation_tag, signer)
};
bytes
GroupInfo::to_be_signed() const
{
return tls::marshal(
GroupInfoTBS{ group_context, extensions, confirmation_tag, signer });
}
void
GroupInfo::sign(const TreeKEMPublicKey& tree,
LeafIndex signer_index,
const SignaturePrivateKey& priv)
{
auto maybe_leaf = tree.leaf_node(signer_index);
if (!maybe_leaf) {
throw InvalidParameterError("Cannot sign from a blank leaf");
}
if (priv.public_key != opt::get(maybe_leaf).signature_key) {
throw InvalidParameterError("Bad key for index");
}
signer = signer_index;
signature = priv.sign(tree.suite, sign_label::group_info, to_be_signed());
}
bool
GroupInfo::verify(const TreeKEMPublicKey& tree) const
{
auto maybe_leaf = tree.leaf_node(signer);
if (!maybe_leaf) {
throw InvalidParameterError("Signer not found");
}
const auto& leaf = opt::get(maybe_leaf);
return verify(leaf.signature_key);
}
void
GroupInfo::sign(LeafIndex signer_index, const SignaturePrivateKey& priv)
{
signer = signer_index;
signature = priv.sign(
group_context.cipher_suite, sign_label::group_info, to_be_signed());
}
bool
GroupInfo::verify(const SignaturePublicKey& pub) const
{
return pub.verify(group_context.cipher_suite,
sign_label::group_info,
to_be_signed(),
signature);
}
// Welcome
Welcome::Welcome()
: cipher_suite(CipherSuite::ID::unknown)
{
}
Welcome::Welcome(CipherSuite suite,
const bytes& joiner_secret,
const std::vector<PSKWithSecret>& psks,
const GroupInfo& group_info)
: cipher_suite(suite)
, _joiner_secret(joiner_secret)
{
// Cache the list of PSK IDs
for (const auto& psk : psks) {
_psks.psks.push_back(psk.id);
}
// Pre-encrypt the GroupInfo
auto [key, nonce] = group_info_key_nonce(suite, joiner_secret, psks);
auto group_info_data = tls::marshal(group_info);
encrypted_group_info =
cipher_suite.hpke().aead.seal(key, nonce, {}, group_info_data);
}
std::optional<int>
Welcome::find(const KeyPackage& kp) const
{
auto ref = kp.ref();
for (size_t i = 0; i < secrets.size(); i++) {
if (ref == secrets[i].new_member) {
return static_cast<int>(i);
}
}
return std::nullopt;
}
void
Welcome::encrypt(const KeyPackage& kp, const std::optional<bytes>& path_secret)
{
auto gs = GroupSecrets{ _joiner_secret, std::nullopt, _psks };
if (path_secret) {
gs.path_secret = GroupSecrets::PathSecret{ opt::get(path_secret) };
}
auto gs_data = tls::marshal(gs);
auto enc_gs = kp.init_key.encrypt(
kp.cipher_suite, encrypt_label::welcome, encrypted_group_info, gs_data);
secrets.push_back({ kp.ref(), enc_gs });
}
GroupSecrets
Welcome::decrypt_secrets(int kp_index, const HPKEPrivateKey& init_priv) const
{
auto secrets_data =
init_priv.decrypt(cipher_suite,
encrypt_label::welcome,
encrypted_group_info,
secrets.at(kp_index).encrypted_group_secrets);
return tls::get<GroupSecrets>(secrets_data);
}
GroupInfo
Welcome::decrypt(const bytes& joiner_secret,
const std::vector<PSKWithSecret>& psks) const
{
auto [key, nonce] = group_info_key_nonce(cipher_suite, joiner_secret, psks);
auto group_info_data =
cipher_suite.hpke().aead.open(key, nonce, {}, encrypted_group_info);
if (!group_info_data) {
throw ProtocolError("Welcome decryption failed");
}
return tls::get<GroupInfo>(opt::get(group_info_data));
}
KeyAndNonce
Welcome::group_info_key_nonce(CipherSuite suite,
const bytes& joiner_secret,
const std::vector<PSKWithSecret>& psks)
{
auto welcome_secret =
KeyScheduleEpoch::welcome_secret(suite, joiner_secret, psks);
// XXX(RLB): These used to be done with ExpandWithLabel. Should we do that
// instead, for better domain separation? (In particular, including "mls10")
// That is what we do for the sender data key/nonce.
auto key =
suite.expand_with_label(welcome_secret, "key", {}, suite.key_size());
auto nonce =
suite.expand_with_label(welcome_secret, "nonce", {}, suite.nonce_size());
return { std::move(key), std::move(nonce) };
}
// Commit
std::optional<bytes>
Commit::valid_external() const
{
// External Commits MUST contain a path field (and is therefore a "full"
// Commit). The joiner is added at the leftmost free leaf node (just as if
// they were added with an Add proposal), and the path is calculated relative
// to that leaf node.
//
// The Commit MUST NOT include any proposals by reference, since an external
// joiner cannot determine the validity of proposals sent within the group
const auto all_by_value = stdx::all_of(proposals, [](const auto& p) {
return var::holds_alternative<Proposal>(p.content);
});
if (!path || !all_by_value) {
return std::nullopt;
}
const auto ext_init_ptr = stdx::find_if(proposals, [](const auto& p) {
const auto proposal = var::get<Proposal>(p.content);
return proposal.proposal_type() == ProposalType::external_init;
});
if (ext_init_ptr == proposals.end()) {
return std::nullopt;
}
const auto& ext_init_proposal = var::get<Proposal>(ext_init_ptr->content);
const auto& ext_init = var::get<ExternalInit>(ext_init_proposal.content);
return ext_init.kem_output;
}
// PublicMessage
Proposal::Type
Proposal::proposal_type() const
{
return tls::variant<ProposalType>::type(content).val;
}
SenderType
Sender::sender_type() const
{
return tls::variant<SenderType>::type(sender);
}
tls::ostream&
operator<<(tls::ostream& str, const GroupContentAuthData& obj)
{
switch (obj.content_type) {
case ContentType::proposal:
case ContentType::application:
return str << obj.signature;
case ContentType::commit:
return str << obj.signature << opt::get(obj.confirmation_tag);
default:
throw InvalidParameterError("Invalid content type");
}
}
tls::istream&
operator>>(tls::istream& str, GroupContentAuthData& obj)
{
switch (obj.content_type) {
case ContentType::proposal:
case ContentType::application:
return str >> obj.signature;
case ContentType::commit:
obj.confirmation_tag.emplace();
return str >> obj.signature >> opt::get(obj.confirmation_tag);
default:
throw InvalidParameterError("Invalid content type");
}
}
bool
operator==(const GroupContentAuthData& lhs, const GroupContentAuthData& rhs)
{
return lhs.content_type == rhs.content_type &&
lhs.signature == rhs.signature &&
lhs.confirmation_tag == rhs.confirmation_tag;
}
GroupContent::GroupContent(bytes group_id_in,
epoch_t epoch_in,
Sender sender_in,
bytes authenticated_data_in,
RawContent content_in)
: group_id(std::move(group_id_in))
, epoch(epoch_in)
, sender(sender_in)
, authenticated_data(std::move(authenticated_data_in))
, content(std::move(content_in))
{
}
GroupContent::GroupContent(bytes group_id_in,
epoch_t epoch_in,
Sender sender_in,
bytes authenticated_data_in,
ContentType content_type)
: group_id(std::move(group_id_in))
, epoch(epoch_in)
, sender(sender_in)
, authenticated_data(std::move(authenticated_data_in))
{
switch (content_type) {
case ContentType::commit:
content.emplace<Commit>();
break;
case ContentType::proposal:
content.emplace<Proposal>();
break;
case ContentType::application:
content.emplace<ApplicationData>();
break;
default:
throw InvalidParameterError("Invalid content type");
}
}
ContentType
GroupContent::content_type() const
{
return tls::variant<ContentType>::type(content);
}
AuthenticatedContent
AuthenticatedContent::sign(WireFormat wire_format,
GroupContent content,
CipherSuite suite,
const SignaturePrivateKey& sig_priv,
const std::optional<GroupContext>& context)
{
if (wire_format == WireFormat::mls_public_message &&
content.content_type() == ContentType::application) {
throw InvalidParameterError(
"Application data cannot be sent as PublicMessage");
}
auto content_auth = AuthenticatedContent{ wire_format, std::move(content) };
auto tbs = content_auth.to_be_signed(context);
content_auth.auth.signature =
sig_priv.sign(suite, sign_label::mls_content, tbs);
return content_auth;
}
bool
AuthenticatedContent::verify(CipherSuite suite,
const SignaturePublicKey& sig_pub,
const std::optional<GroupContext>& context) const
{
if (wire_format == WireFormat::mls_public_message &&
content.content_type() == ContentType::application) {
return false;
}
auto tbs = to_be_signed(context);
return sig_pub.verify(suite, sign_label::mls_content, tbs, auth.signature);
}
struct ConfirmedTranscriptHashInput
{
WireFormat wire_format;
const GroupContent& content;
const bytes& signature;
TLS_SERIALIZABLE(wire_format, content, signature);
};
struct InterimTranscriptHashInput
{
const bytes& confirmation_tag;
TLS_SERIALIZABLE(confirmation_tag);
};
bytes
AuthenticatedContent::confirmed_transcript_hash_input() const
{
return tls::marshal(ConfirmedTranscriptHashInput{
wire_format,
content,
auth.signature,
});
}
bytes
AuthenticatedContent::interim_transcript_hash_input() const
{
return tls::marshal(
InterimTranscriptHashInput{ opt::get(auth.confirmation_tag) });
}
void
AuthenticatedContent::set_confirmation_tag(const bytes& confirmation_tag)
{
auth.confirmation_tag = confirmation_tag;
}
bool
AuthenticatedContent::check_confirmation_tag(
const bytes& confirmation_tag) const
{
return confirmation_tag == opt::get(auth.confirmation_tag);
}
tls::ostream&
operator<<(tls::ostream& str, const AuthenticatedContent& obj)
{
return str << obj.wire_format << obj.content << obj.auth;
}
tls::istream&
operator>>(tls::istream& str, AuthenticatedContent& obj)
{
str >> obj.wire_format >> obj.content;
obj.auth.content_type = obj.content.content_type();
return str >> obj.auth;
}
bool
operator==(const AuthenticatedContent& lhs, const AuthenticatedContent& rhs)
{
return lhs.wire_format == rhs.wire_format && lhs.content == rhs.content &&
lhs.auth == rhs.auth;
}
AuthenticatedContent::AuthenticatedContent(WireFormat wire_format_in,
GroupContent content_in)
: wire_format(wire_format_in)
, content(std::move(content_in))
{
auth.content_type = content.content_type();
}
AuthenticatedContent::AuthenticatedContent(WireFormat wire_format_in,
GroupContent content_in,
GroupContentAuthData auth_in)
: wire_format(wire_format_in)
, content(std::move(content_in))
, auth(std::move(auth_in))
{
}
const AuthenticatedContent&
ValidatedContent::authenticated_content() const
{
return content_auth;
}
ValidatedContent::ValidatedContent(AuthenticatedContent content_auth_in)
: content_auth(std::move(content_auth_in))
{
}
bool
operator==(const ValidatedContent& lhs, const ValidatedContent& rhs)
{
return lhs.content_auth == rhs.content_auth;
}
struct GroupContentTBS
{
WireFormat wire_format = WireFormat::reserved;
const GroupContent& content;
const std::optional<GroupContext>& context;
};
static tls::ostream&
operator<<(tls::ostream& str, const GroupContentTBS& obj)
{
str << ProtocolVersion::mls10 << obj.wire_format << obj.content;
switch (obj.content.sender.sender_type()) {
case SenderType::member:
case SenderType::new_member_commit:
str << opt::get(obj.context);
break;
case SenderType::external:
case SenderType::new_member_proposal:
break;
default:
throw InvalidParameterError("Invalid sender type");
}
return str;
}
bytes
AuthenticatedContent::to_be_signed(
const std::optional<GroupContext>& context) const
{
return tls::marshal(GroupContentTBS{
wire_format,
content,
context,
});
}
PublicMessage
PublicMessage::protect(AuthenticatedContent content_auth,
CipherSuite suite,
const std::optional<bytes>& membership_key,
const std::optional<GroupContext>& context)
{
auto pt = PublicMessage(std::move(content_auth));
// Add the membership_mac if required
switch (pt.content.sender.sender_type()) {
case SenderType::member:
pt.membership_tag =
pt.membership_mac(suite, opt::get(membership_key), context);
break;
default:
break;
}
return pt;
}
std::optional<ValidatedContent>
PublicMessage::unprotect(CipherSuite suite,
const std::optional<bytes>& membership_key,
const std::optional<GroupContext>& context) const
{
// Verify the membership_tag if the message was sent within the group
switch (content.sender.sender_type()) {
case SenderType::member: {
auto candidate = membership_mac(suite, opt::get(membership_key), context);
if (candidate != opt::get(membership_tag)) {
return std::nullopt;
}
break;
}
default:
break;
}
return { { AuthenticatedContent{
WireFormat::mls_public_message,
content,
auth,
} } };
}
bool
PublicMessage::contains(const AuthenticatedContent& content_auth) const
{
return content == content_auth.content && auth == content_auth.auth;
}
AuthenticatedContent
PublicMessage::authenticated_content() const
{
auto auth_content = AuthenticatedContent{};
auth_content.wire_format = WireFormat::mls_public_message;
auth_content.content = content;
auth_content.auth = auth;
return auth_content;
}
PublicMessage::PublicMessage(AuthenticatedContent content_auth)
: content(std::move(content_auth.content))
, auth(std::move(content_auth.auth))
{
if (content_auth.wire_format != WireFormat::mls_public_message) {
throw InvalidParameterError("Wire format mismatch (not mls_plaintext)");
}
}
struct GroupContentTBM
{
GroupContentTBS content_tbs;
GroupContentAuthData auth;
TLS_SERIALIZABLE(content_tbs, auth);
};
bytes
PublicMessage::membership_mac(CipherSuite suite,
const bytes& membership_key,
const std::optional<GroupContext>& context) const
{
auto tbm = tls::marshal(GroupContentTBM{
{ WireFormat::mls_public_message, content, context },
auth,
});
return suite.digest().hmac(membership_key, tbm);
}
tls::ostream&
operator<<(tls::ostream& str, const PublicMessage& obj)
{
switch (obj.content.sender.sender_type()) {
case SenderType::member:
return str << obj.content << obj.auth << opt::get(obj.membership_tag);
case SenderType::external:
case SenderType::new_member_proposal:
case SenderType::new_member_commit:
return str << obj.content << obj.auth;
default:
throw InvalidParameterError("Invalid sender type");
}
}
tls::istream&
operator>>(tls::istream& str, PublicMessage& obj)
{
str >> obj.content;
obj.auth.content_type = obj.content.content_type();
str >> obj.auth;
if (obj.content.sender.sender_type() == SenderType::member) {
obj.membership_tag.emplace();
str >> opt::get(obj.membership_tag);
}
return str;
}
bool
operator==(const PublicMessage& lhs, const PublicMessage& rhs)
{
return lhs.content == rhs.content && lhs.auth == rhs.auth &&
lhs.membership_tag == rhs.membership_tag;
}
bool
operator!=(const PublicMessage& lhs, const PublicMessage& rhs)
{
return !(lhs == rhs);
}
static bytes
marshal_ciphertext_content(const GroupContent& content,
const GroupContentAuthData& auth,
size_t padding_size)
{
auto w = tls::ostream{};
var::visit([&w](const auto& val) { w << val; }, content.content);
w << auth;
w.write_raw(bytes(padding_size, 0));
return w.bytes();
}
static void
unmarshal_ciphertext_content(const bytes& content_pt,
GroupContent& content,
GroupContentAuthData& auth)
{
auto r = tls::istream(content_pt);
var::visit([&r](auto& val) { r >> val; }, content.content);
r >> auth;
const auto padding = r.bytes();
const auto nonzero = [](const auto& x) { return x != 0; };
if (stdx::any_of(padding, nonzero)) {
throw ProtocolError("Malformed AuthenticatedContentTBE padding");
}
}
struct ContentAAD
{
const bytes& group_id;
const epoch_t epoch;
const ContentType content_type;
const bytes& authenticated_data;
TLS_SERIALIZABLE(group_id, epoch, content_type, authenticated_data)
};
struct SenderData
{
LeafIndex sender{ 0 };
uint32_t generation{ 0 };
ReuseGuard reuse_guard{ 0, 0, 0, 0 };
TLS_SERIALIZABLE(sender, generation, reuse_guard)
};
struct SenderDataAAD
{
const bytes& group_id;
const epoch_t epoch;
const ContentType content_type;
TLS_SERIALIZABLE(group_id, epoch, content_type)
};
PrivateMessage
PrivateMessage::protect(AuthenticatedContent content_auth,
CipherSuite suite,
GroupKeySource& keys,
const bytes& sender_data_secret,
size_t padding_size)
{
// Pull keys from the secret tree
auto index =
var::get<MemberSender>(content_auth.content.sender.sender).sender;
auto content_type = content_auth.content.content_type();
auto [generation, reuse_guard, content_keys] = keys.next(content_type, index);
// Encrypt the content
auto content_pt = marshal_ciphertext_content(
content_auth.content, content_auth.auth, padding_size);
auto content_aad = tls::marshal(ContentAAD{
content_auth.content.group_id,
content_auth.content.epoch,
content_auth.content.content_type(),
content_auth.content.authenticated_data,
});
auto content_ct = suite.hpke().aead.seal(
content_keys.key, content_keys.nonce, content_aad, content_pt);
// Encrypt the sender data
auto sender_index =
var::get<MemberSender>(content_auth.content.sender.sender).sender;
auto sender_data_pt = tls::marshal(SenderData{
sender_index,
generation,
reuse_guard,
});
auto sender_data_aad = tls::marshal(SenderDataAAD{
content_auth.content.group_id,
content_auth.content.epoch,
content_auth.content.content_type(),
});
auto sender_data_keys =
KeyScheduleEpoch::sender_data_keys(suite, sender_data_secret, content_ct);
auto sender_data_ct = suite.hpke().aead.seal(sender_data_keys.key,
sender_data_keys.nonce,
sender_data_aad,
sender_data_pt);
return PrivateMessage{
std::move(content_auth.content),
std::move(sender_data_ct),
std::move(content_ct),
};
}
std::optional<ValidatedContent>
PrivateMessage::unprotect(CipherSuite suite,
GroupKeySource& keys,
const bytes& sender_data_secret) const
{
// Decrypt and parse the sender data
auto sender_data_keys =
KeyScheduleEpoch::sender_data_keys(suite, sender_data_secret, ciphertext);
auto sender_data_aad = tls::marshal(SenderDataAAD{
group_id,
epoch,
content_type,
});
auto sender_data_pt = suite.hpke().aead.open(sender_data_keys.key,
sender_data_keys.nonce,
sender_data_aad,
encrypted_sender_data);
if (!sender_data_pt) {
return std::nullopt;
}
auto sender_data = tls::get<SenderData>(opt::get(sender_data_pt));
if (!keys.has_leaf(sender_data.sender)) {
return std::nullopt;
}
// Decrypt the content
auto content_keys = keys.get(content_type,
sender_data.sender,
sender_data.generation,
sender_data.reuse_guard);
keys.erase(content_type, sender_data.sender, sender_data.generation);
auto content_aad = tls::marshal(ContentAAD{
group_id,
epoch,
content_type,
authenticated_data,
});
auto content_pt = suite.hpke().aead.open(
content_keys.key, content_keys.nonce, content_aad, ciphertext);
if (!content_pt) {
return std::nullopt;
}
// Parse the content
auto content = GroupContent{ group_id,
epoch,
{ MemberSender{ sender_data.sender } },
authenticated_data,
content_type };
auto auth = GroupContentAuthData{ content_type, {}, {} };
unmarshal_ciphertext_content(opt::get(content_pt), content, auth);
return { { AuthenticatedContent{
WireFormat::mls_private_message,
std::move(content),
std::move(auth),
} } };
}
PrivateMessage::PrivateMessage(GroupContent content,
bytes encrypted_sender_data_in,
bytes ciphertext_in)
: group_id(std::move(content.group_id))
, epoch(content.epoch)
, content_type(content.content_type())
, authenticated_data(std::move(content.authenticated_data))
, encrypted_sender_data(std::move(encrypted_sender_data_in))
, ciphertext(std::move(ciphertext_in))
{
}
bytes
MLSMessage::group_id() const
{
return var::visit(
overloaded{
[](const PublicMessage& pt) -> bytes { return pt.get_group_id(); },
[](const PrivateMessage& ct) -> bytes { return ct.get_group_id(); },
[](const GroupInfo& gi) -> bytes { return gi.group_context.group_id; },
[](const auto& /* unused */) -> bytes {
throw InvalidParameterError("MLSMessage has no group_id");
},
},
message);
}
epoch_t
MLSMessage::epoch() const
{
return var::visit(
overloaded{
[](const PublicMessage& pt) -> epoch_t { return pt.get_epoch(); },
[](const PrivateMessage& pt) -> epoch_t { return pt.get_epoch(); },
[](const auto& /* unused */) -> epoch_t {
throw InvalidParameterError("MLSMessage has no epoch");
},
},
message);
}
WireFormat
MLSMessage::wire_format() const
{
return tls::variant<WireFormat>::type(message);
}
MLSMessage::MLSMessage(PublicMessage public_message)
: message(std::move(public_message))
{
}
MLSMessage::MLSMessage(PrivateMessage private_message)
: message(std::move(private_message))
{
}
MLSMessage::MLSMessage(Welcome welcome)
: message(std::move(welcome))
{
}
MLSMessage::MLSMessage(GroupInfo group_info)
: message(std::move(group_info))
{
}
MLSMessage::MLSMessage(KeyPackage key_package)
: message(std::move(key_package))
{
}
MLSMessage
external_proposal(CipherSuite suite,
const bytes& group_id,
epoch_t epoch,
const Proposal& proposal,
uint32_t signer_index,
const SignaturePrivateKey& sig_priv)
{
switch (proposal.proposal_type()) {
// These proposal types are OK
case ProposalType::add:
case ProposalType::remove:
case ProposalType::psk:
case ProposalType::reinit:
case ProposalType::group_context_extensions:
break;
// These proposal types are forbidden
case ProposalType::invalid:
case ProposalType::update:
case ProposalType::external_init:
default:
throw ProtocolError("External proposal has invalid type");
}
auto content = GroupContent{ group_id,
epoch,
{ ExternalSenderIndex{ signer_index } },
{ /* no authenticated data */ },
{ proposal } };
auto content_auth = AuthenticatedContent::sign(
WireFormat::mls_public_message, std::move(content), suite, sig_priv, {});
return PublicMessage::protect(std::move(content_auth), suite, {}, {});
}
} // namespace mlspp

437
DPP/mlspp/src/session.cpp Executable file
View File

@@ -0,0 +1,437 @@
#include <mls/messages.h>
#include <mls/session.h>
#include <deque>
namespace mlspp {
///
/// Inner struct declarations for PendingJoin and Session
///
struct PendingJoin::Inner
{
const CipherSuite suite;
const HPKEPrivateKey init_priv;
const HPKEPrivateKey leaf_priv;
const SignaturePrivateKey sig_priv;
const KeyPackage key_package;
Inner(CipherSuite suite_in,
SignaturePrivateKey sig_priv_in,
Credential cred_in);
static PendingJoin create(CipherSuite suite,
SignaturePrivateKey sig_priv,
Credential cred);
};
struct Session::Inner
{
std::deque<State> history;
std::map<bytes, State> outbound_cache;
bool encrypt_handshake{ false };
explicit Inner(State state);
static Session begin(CipherSuite suite,
const bytes& group_id,
const HPKEPrivateKey& leaf_priv,
const SignaturePrivateKey& sig_priv,
const LeafNode& leaf_node);
static Session join(const HPKEPrivateKey& init_priv,
const HPKEPrivateKey& leaf_priv,
const SignaturePrivateKey& sig_priv,
const KeyPackage& key_package,
const bytes& welcome_data);
bytes fresh_secret() const;
MLSMessage import_handshake(const bytes& encoded) const;
State& for_epoch(epoch_t epoch);
};
///
/// Client
///
Client::Client(CipherSuite suite_in,
SignaturePrivateKey sig_priv_in,
Credential cred_in)
: suite(suite_in)
, sig_priv(std::move(sig_priv_in))
, cred(std::move(cred_in))
{
}
Session
Client::begin_session(const bytes& group_id) const
{
auto leaf_priv = HPKEPrivateKey::generate(suite);
auto leaf_node = LeafNode(suite,
leaf_priv.public_key,
sig_priv.public_key,
cred,
Capabilities::create_default(),
Lifetime::create_default(),
{},
sig_priv);
return Session::Inner::begin(suite, group_id, leaf_priv, sig_priv, leaf_node);
}
PendingJoin
Client::start_join() const
{
return PendingJoin::Inner::create(suite, sig_priv, cred);
}
///
/// PendingJoin
///
PendingJoin::Inner::Inner(CipherSuite suite_in,
SignaturePrivateKey sig_priv_in,
Credential cred_in)
: suite(suite_in)
, init_priv(HPKEPrivateKey::generate(suite))
, leaf_priv(HPKEPrivateKey::generate(suite))
, sig_priv(std::move(sig_priv_in))
, key_package(suite,
init_priv.public_key,
LeafNode(suite,
leaf_priv.public_key,
sig_priv.public_key,
std::move(cred_in),
Capabilities::create_default(),
Lifetime::create_default(),
{},
sig_priv),
{},
sig_priv)
{
}
PendingJoin
PendingJoin::Inner::create(CipherSuite suite,
SignaturePrivateKey sig_priv,
Credential cred)
{
auto inner =
std::make_unique<Inner>(suite, std::move(sig_priv), std::move(cred));
return { inner.release() };
}
PendingJoin::PendingJoin(PendingJoin&& other) noexcept = default;
PendingJoin&
PendingJoin::operator=(PendingJoin&& other) noexcept = default;
PendingJoin::~PendingJoin() = default;
PendingJoin::PendingJoin(Inner* inner_in)
: inner(inner_in)
{
}
bytes
PendingJoin::key_package() const
{
return tls::marshal(inner->key_package);
}
Session
PendingJoin::complete(const bytes& welcome) const
{
return Session::Inner::join(inner->init_priv,
inner->leaf_priv,
inner->sig_priv,
inner->key_package,
welcome);
}
///
/// Session
///
Session::Inner::Inner(State state)
: history{ std::move(state) }
, encrypt_handshake(true)
{
}
Session
Session::Inner::begin(CipherSuite suite,
const bytes& group_id,
const HPKEPrivateKey& leaf_priv,
const SignaturePrivateKey& sig_priv,
const LeafNode& leaf_node)
{
auto state = State(group_id, suite, leaf_priv, sig_priv, leaf_node, {});
auto inner = std::make_unique<Inner>(state);
return { inner.release() };
}
Session
Session::Inner::join(const HPKEPrivateKey& init_priv,
const HPKEPrivateKey& leaf_priv,
const SignaturePrivateKey& sig_priv,
const KeyPackage& key_package,
const bytes& welcome_data)
{
auto welcome = tls::get<Welcome>(welcome_data);
auto state = State(
init_priv, leaf_priv, sig_priv, key_package, welcome, std::nullopt, {});
auto inner = std::make_unique<Inner>(state);
return { inner.release() };
}
bytes
Session::Inner::fresh_secret() const
{
const auto suite = history.front().cipher_suite();
return random_bytes(suite.secret_size());
}
MLSMessage
Session::Inner::import_handshake(const bytes& encoded) const
{
auto msg = tls::get<MLSMessage>(encoded);
switch (msg.wire_format()) {
case WireFormat::mls_public_message:
if (encrypt_handshake) {
throw ProtocolError("Handshake not encrypted as required");
}
return msg;
case WireFormat::mls_private_message: {
if (!encrypt_handshake) {
throw ProtocolError("Unexpected handshake encryption");
}
return msg;
}
default:
throw InvalidParameterError("Illegal wire format");
}
}
State&
Session::Inner::for_epoch(epoch_t epoch)
{
for (auto& state : history) {
if (state.epoch() == epoch) {
return state;
}
}
throw MissingStateError("No state for epoch");
}
Session::Session(Session&& other) noexcept = default;
Session&
Session::operator=(Session&& other) noexcept = default;
Session::~Session() = default;
Session::Session(Inner* inner_in)
: inner(inner_in)
{
}
void
Session::encrypt_handshake(bool enabled)
{
inner->encrypt_handshake = enabled;
}
bytes
Session::add(const bytes& key_package_data)
{
auto key_package = tls::get<KeyPackage>(key_package_data);
auto proposal = inner->history.front().add(
key_package, { inner->encrypt_handshake, {}, 0 });
return tls::marshal(proposal);
}
bytes
Session::update()
{
auto leaf_secret = inner->fresh_secret();
auto leaf_priv = HPKEPrivateKey::generate(cipher_suite());
auto proposal = inner->history.front().update(
std::move(leaf_priv), {}, { inner->encrypt_handshake, {}, 0 });
return tls::marshal(proposal);
}
bytes
Session::remove(uint32_t index)
{
auto proposal = inner->history.front().remove(
RosterIndex{ index }, { inner->encrypt_handshake, {}, 0 });
return tls::marshal(proposal);
}
std::tuple<bytes, bytes>
Session::commit(const bytes& proposal)
{
return commit(std::vector<bytes>{ proposal });
}
std::tuple<bytes, bytes>
Session::commit(const std::vector<bytes>& proposals)
{
auto provisional_state = inner->history.front();
for (const auto& proposal_data : proposals) {
auto msg = inner->import_handshake(proposal_data);
auto maybe_state = provisional_state.handle(msg);
if (maybe_state) {
throw InvalidParameterError("Invalid proposal; actually a commit");
}
}
inner->history.front() = std::move(provisional_state);
return commit();
}
std::tuple<bytes, bytes>
Session::commit()
{
auto commit_secret = inner->fresh_secret();
auto encrypt = inner->encrypt_handshake;
auto [commit, welcome, new_state] = inner->history.front().commit(
commit_secret, CommitOpts{ {}, true, encrypt, {} }, { encrypt, {}, 0 });
auto commit_msg = tls::marshal(commit);
auto welcome_msg = tls::marshal(welcome);
inner->outbound_cache.insert({ commit_msg, new_state });
return std::make_tuple(welcome_msg, commit_msg);
}
bool
Session::handle(const bytes& handshake_data)
{
auto msg = inner->import_handshake(handshake_data);
auto maybe_cached_state = std::optional<State>{};
auto node = inner->outbound_cache.extract(handshake_data);
if (!node.empty()) {
maybe_cached_state = node.mapped();
}
auto maybe_next_state =
inner->history.front().handle(msg, maybe_cached_state);
if (!maybe_next_state) {
return false;
}
inner->history.emplace_front(opt::get(maybe_next_state));
return true;
}
epoch_t
Session::epoch() const
{
return inner->history.front().epoch();
}
LeafIndex
Session::index() const
{
return inner->history.front().index();
}
CipherSuite
Session::cipher_suite() const
{
return inner->history.front().cipher_suite();
}
const ExtensionList&
Session::extensions() const
{
return inner->history.front().extensions();
}
const TreeKEMPublicKey&
Session::tree() const
{
return inner->history.front().tree();
}
bytes
Session::do_export(const std::string& label,
const bytes& context,
size_t size) const
{
return inner->history.front().do_export(label, context, size);
}
GroupInfo
Session::group_info() const
{
return inner->history.front().group_info(true);
}
std::vector<LeafNode>
Session::roster() const
{
return inner->history.front().roster();
}
bytes
Session::epoch_authenticator() const
{
return inner->history.front().epoch_authenticator();
}
bytes
Session::protect(const bytes& plaintext)
{
auto msg = inner->history.front().protect({}, plaintext, 0);
return tls::marshal(msg);
}
// TODO(rlb@ipv.sx): It would be good to expose identity information
// here, since ciphertexts are authenticated per sender. Who sent
// this ciphertext?
bytes
Session::unprotect(const bytes& ciphertext)
{
auto ciphertext_obj = tls::get<MLSMessage>(ciphertext);
auto& state = inner->for_epoch(ciphertext_obj.epoch());
auto [aad, pt] = state.unprotect(ciphertext_obj);
silence_unused(aad);
return pt;
}
bool
operator==(const Session& lhs, const Session& rhs)
{
if (lhs.inner->encrypt_handshake != rhs.inner->encrypt_handshake) {
return false;
}
auto size = std::min(lhs.inner->history.size(), rhs.inner->history.size());
for (size_t i = 0; i < size; i += 1) {
if (lhs.inner->history.at(i) != rhs.inner->history.at(i)) {
return false;
}
}
return true;
}
bool
operator!=(const Session& lhs, const Session& rhs)
{
return !(lhs == rhs);
}
} // namespace mlspp

2219
DPP/mlspp/src/state.cpp Executable file

File diff suppressed because it is too large Load Diff

223
DPP/mlspp/src/tree_math.cpp Executable file
View File

@@ -0,0 +1,223 @@
#include "mls/tree_math.h"
#include "mls/common.h"
#include <algorithm>
static const uint32_t one = 0x01;
static uint32_t
log2(uint32_t x)
{
if (x == 0) {
return 0;
}
uint32_t k = 0;
while ((x >> k) > 0) {
k += 1;
}
return k - 1;
}
namespace mlspp {
LeafCount::LeafCount(const NodeCount w)
{
if (w.val == 0) {
val = 0;
return;
}
if ((w.val & one) == 0) {
throw InvalidParameterError("Only odd node counts describe trees");
}
val = (w.val >> one) + 1;
}
LeafCount
LeafCount::full(const LeafCount n)
{
auto w = uint32_t(1);
while (w < n.val) {
w <<= 1U;
}
return LeafCount{ w };
}
NodeCount::NodeCount(const LeafCount n)
: UInt32(2 * (n.val - 1) + 1)
{
}
LeafIndex::LeafIndex(NodeIndex x)
: UInt32(0)
{
if (x.val % 2 == 1) {
throw InvalidParameterError("Only even node indices describe leaves");
}
val = x.val >> 1; // NOLINT(hicpp-signed-bitwise)
}
NodeIndex
LeafIndex::ancestor(LeafIndex other) const
{
auto ln = NodeIndex(*this);
auto rn = NodeIndex(other);
if (ln == rn) {
return ln;
}
uint8_t k = 0;
while (ln != rn) {
ln.val = ln.val >> 1U;
rn.val = rn.val >> 1U;
k += 1;
}
const uint32_t prefix = ln.val << k;
const uint32_t stop = (1U << uint8_t(k - 1));
return NodeIndex{ prefix + (stop - 1) };
}
NodeIndex::NodeIndex(LeafIndex x)
: UInt32(2 * x.val)
{
}
NodeIndex
NodeIndex::root(LeafCount n)
{
if (n.val == 0) {
throw std::runtime_error("Root for zero-size tree is undefined");
}
auto w = NodeCount(n);
return NodeIndex{ (one << log2(w.val)) - 1 };
}
bool
NodeIndex::is_leaf() const
{
return val % 2 == 0;
}
bool
NodeIndex::is_below(NodeIndex other) const
{
auto lx = level();
auto ly = other.level();
return lx <= ly && (val >> (ly + 1) == other.val >> (ly + 1));
}
NodeIndex
NodeIndex::left() const
{
if (is_leaf()) {
return *this;
}
// The clang analyzer doesn't realize that is_leaf() assures that level >= 1
// NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult)
return NodeIndex{ val ^ (one << (level() - 1)) };
}
NodeIndex
NodeIndex::right() const
{
if (is_leaf()) {
return *this;
}
return NodeIndex{ val ^ (uint32_t(0x03) << (level() - 1)) };
}
NodeIndex
NodeIndex::parent() const
{
auto k = level();
return NodeIndex{ (val | (one << k)) & ~(one << (k + 1)) };
}
NodeIndex
NodeIndex::sibling() const
{
return sibling(parent());
}
NodeIndex
NodeIndex::sibling(NodeIndex ancestor) const
{
if (!is_below(ancestor)) {
throw InvalidParameterError("Node is not below claimed ancestor");
}
auto l = ancestor.left();
auto r = ancestor.right();
if (is_below(l)) {
return r;
}
return l;
}
std::vector<NodeIndex>
NodeIndex::dirpath(LeafCount n)
{
if (val >= NodeCount(n).val) {
throw InvalidParameterError("Request for dirpath outside of tree");
}
auto d = std::vector<NodeIndex>{};
auto r = root(n);
if (*this == r) {
return d;
}
auto p = parent();
while (p.val != r.val) {
d.push_back(p);
p = p.parent();
}
// Include the root except in a one-member tree
if (val != r.val) {
d.push_back(p);
}
return d;
}
std::vector<NodeIndex>
NodeIndex::copath(LeafCount n)
{
auto d = dirpath(n);
if (d.empty()) {
return {};
}
// Prepend leaf; omit root
d.insert(d.begin(), *this);
d.pop_back();
return stdx::transform<NodeIndex>(d, [](auto x) { return x.sibling(); });
}
uint32_t
NodeIndex::level() const
{
if ((val & one) == 0) {
return 0;
}
uint32_t k = 0;
while (((val >> k) & one) == 1) {
k += 1;
}
return k;
}
} // namespace mlspp

1127
DPP/mlspp/src/treekem.cpp Executable file

File diff suppressed because it is too large Load Diff