add basic discord support

This commit is contained in:
2024-11-09 01:58:06 +03:00
parent c3e7a9c92d
commit 8965b7ee90
869 changed files with 191278 additions and 7 deletions

117
DPP/mlspp/CMakeLists.txt Executable file
View File

@@ -0,0 +1,117 @@
cmake_minimum_required(VERSION 3.13)
project(mlspp
VERSION 0.1
LANGUAGES CXX
)
option(TESTING "Build tests" OFF)
option(CLANG_TIDY "Perform linting with clang-tidy" OFF)
option(SANITIZERS "Enable sanitizers" OFF)
option(MLS_NAMESPACE_SUFFIX "Namespace Suffix for CXX and CMake Export")
option(DISABLE_GREASE "Disables the inclusion of MLS protocol recommended GREASE values" ON)
option(REQUIRE_BORINGSSL "Require BoringSSL instead of OpenSSL" OFF)
if(MLS_NAMESPACE_SUFFIX)
set(MLS_CXX_NAMESPACE "mls_${MLS_NAMESPACE_SUFFIX}" CACHE STRING "Top-level Namespace for CXX")
set(MLS_EXPORT_NAMESPACE "MLSPP${MLS_NAMESPACE_SUFFIX}" CACHE STRING "Namespace for CMake Export")
else()
set(MLS_CXX_NAMESPACE "../include/dpp/mlspp/mls" CACHE STRING "Top-level Namespace for CXX")
set(MLS_EXPORT_NAMESPACE "MLSPP" CACHE STRING "Namespace for CMake Export")
endif()
message(STATUS "CXX Namespace: ${MLS_CXX_NAMESPACE}")
message(STATUS "CMake Export Namespace: ${MLS_EXPORT_NAMESPACE}")
###
### Global Config
###
set_property(GLOBAL PROPERTY USE_FOLDERS ON)
configure_file(
"cmake/namespace.h.in"
"${CMAKE_CURRENT_SOURCE_DIR}/include/namespace.h"
@ONLY
)
include(CheckCXXCompilerFlag)
include(CMakePackageConfigHelpers)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
if (CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID MATCHES "GNU")
add_compile_options(-Wall -fPIC)
elseif(MSVC)
add_compile_options(/W2)
add_definitions(-DWINDOWS)
# MSVC helpfully recommends safer equivalents for things like
# getenv, but they are not portable.
add_definitions(-D_CRT_SECURE_NO_WARNINGS)
endif()
if("$ENV{MACOSX_DEPLOYMENT_TARGET}" STREQUAL "10.11")
add_compile_options(-DVARIANT_COMPAT)
endif()
add_compile_options(-DDISABLE_GREASE)
###
### Dependencies
###
# Configure vcpkg to only build release libraries
set(VCPKG_BUILD_TYPE release)
if (${OPENSSL_VERSION} VERSION_GREATER_EQUAL 3)
add_compile_definitions(WITH_OPENSSL3)
elseif(${OPENSSL_VERSION} VERSION_LESS 1.1.1)
message(FATAL_ERROR "OpenSSL 1.1.1 or greater is required")
endif()
message(STATUS "OpenSSL Found: ${OPENSSL_VERSION}")
message(STATUS "OpenSSL Include: ${OPENSSL_INCLUDE_DIR}")
message(STATUS "OpenSSL Libraries: ${OPENSSL_LIBRARIES}")
# Internal libraries
add_subdirectory(lib)
###
### Library Config
###
set(LIB_NAME "${PROJECT_NAME}")
file(GLOB_RECURSE LIB_HEADERS CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/include/*.h")
file(GLOB_RECURSE LIB_SOURCES CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp")
add_library(${LIB_NAME} STATIC ${LIB_HEADERS} ${LIB_SOURCES})
add_dependencies(${LIB_NAME} bytes tls_syntax hpke)
target_link_libraries(${LIB_NAME} bytes tls_syntax hpke)
target_include_directories(${LIB_NAME}
PUBLIC
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/include>
$<INSTALL_INTERFACE:include/${PROJECT_NAME}>
PRIVATE
${OPENSSL_INCLUDE_DIR}
)
###
### Exports
###
set(CMAKE_EXPORT_PACKAGE_REGISTRY ON)
export(PACKAGE ${MLS_EXPORT_NAMESPACE})
configure_package_config_file(cmake/config.cmake.in
${CMAKE_CURRENT_BINARY_DIR}/${MLS_EXPORT_NAMESPACE}Config.cmake
INSTALL_DESTINATION ${CMAKE_INSTALL_DATADIR}/${MLS_EXPORT_NAMESPACE}
NO_SET_AND_CHECK_MACRO)
write_basic_package_version_file(
${CMAKE_CURRENT_BINARY_DIR}/${MLS_EXPORT_NAMESPACE}ConfigVersion.cmake
VERSION ${PROJECT_VERSION}
COMPATIBILITY SameMajorVersion)
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/lib/bytes/include")
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/lib/hpke/include")
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/lib/mls_vectors/include")
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/lib/tls_syntax/include")

25
DPP/mlspp/LICENSE Executable file
View File

@@ -0,0 +1,25 @@
BSD 2-Clause License
Copyright (c) 2018, Cisco Systems
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@@ -0,0 +1,4 @@
@PACKAGE_INIT@
include(${CMAKE_CURRENT_LIST_DIR}/@MLS_EXPORT_NAMESPACE@Targets.cmake)
check_required_components(mlspp)

4
DPP/mlspp/cmake/namespace.h.in Executable file
View File

@@ -0,0 +1,4 @@
#pragma once
// Configurable top-level MLS namespace
#define MLS_NAMESPACE @MLS_CXX_NAMESPACE@

274
DPP/mlspp/include/mls/common.h Executable file
View File

@@ -0,0 +1,274 @@
#pragma once
#include <array>
#include <iomanip>
#include <iterator>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
using namespace std::literals::string_literals;
// Expose the bytes library globally
#include <bytes/bytes.h>
using namespace mlspp::bytes_ns;
// Expose the compatibility library globally
#include <tls/compat.h>
namespace var = mlspp::tls::var;
namespace opt = mlspp::tls::opt;
namespace mlspp {
// Make variant equality work in the same way as optional equality, with
// automatic unwrapping. In other words
//
// v == T(x) <=> hold_alternative<T>(v) && get<T>(v) == x
//
// For consistency, we also define symmetric and negated version. In this
// house, we obey the symmetric law of equivalence relations!
template<typename T, typename... Ts>
bool
operator==(const var::variant<Ts...>& v, const T& t)
{
return var::visit(
[&](const auto& arg) {
using U = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<U, T>) {
return arg == t;
} else {
return false;
}
},
v);
}
template<typename T, typename... Ts>
bool
operator==(const T& t, const var::variant<Ts...>& v)
{
return v == t;
}
template<typename T, typename... Ts>
bool
operator!=(const var::variant<Ts...>& v, const T& t)
{
return !(v == t);
}
template<typename T, typename... Ts>
bool
operator!=(const T& t, const var::variant<Ts...>& v)
{
return !(v == t);
}
using epoch_t = uint64_t;
///
/// Get the current system clock time in the format MLS expects
///
uint64_t
seconds_since_epoch();
///
/// Easy construction of overloaded lambdas
///
template<class... Ts>
struct overloaded : Ts...
{
using Ts::operator()...;
// XXX(RLB) MSVC has a bug where it incorrectly computes the size of this
// type. Microsoft claims they have fixed it in the latest MSVC, and GitHub
// claims they are running a version with the fix. But in practice, we still
// hit it. Including this dummy variable is a work-around.
//
// https://developercommunity.visualstudio.com/t/runtime-stack-corruption-using-stdvisit/346200
int dummy = 0;
};
// clang-format off
// XXX(RLB): For some reason, different versions of clang-format disagree on how
// this should be formatted. Probably because it's new syntax with C++17?
// Exempting it from clang-format for now.
template<class... Ts> overloaded(Ts...) -> overloaded<Ts...>;
// clang-format on
///
/// Auto-generate equality and inequality operators for TLS-serializable things
///
template<typename T>
inline typename std::enable_if<T::_tls_serializable, bool>::type
operator==(const T& lhs, const T& rhs)
{
return lhs._tls_fields_w() == rhs._tls_fields_w();
}
template<typename T>
inline typename std::enable_if<T::_tls_serializable, bool>::type
operator!=(const T& lhs, const T& rhs)
{
return lhs._tls_fields_w() != rhs._tls_fields_w();
}
///
/// Error types
///
// The `using parent = X` / `using parent::parent` construction here
// imports the constructors of the parent.
class NotImplementedError : public std::exception
{
public:
using parent = std::exception;
using parent::parent;
};
class ProtocolError : public std::runtime_error
{
public:
using parent = std::runtime_error;
using parent::parent;
};
class IncompatibleNodesError : public std::invalid_argument
{
public:
using parent = std::invalid_argument;
using parent::parent;
};
class InvalidParameterError : public std::invalid_argument
{
public:
using parent = std::invalid_argument;
using parent::parent;
};
class InvalidPathError : public std::invalid_argument
{
public:
using parent = std::invalid_argument;
using parent::parent;
};
class InvalidIndexError : public std::invalid_argument
{
public:
using parent = std::invalid_argument;
using parent::parent;
};
class InvalidMessageTypeError : public std::invalid_argument
{
public:
using parent = std::invalid_argument;
using parent::parent;
};
class MissingNodeError : public std::out_of_range
{
public:
using parent = std::out_of_range;
using parent::parent;
};
class MissingStateError : public std::out_of_range
{
public:
using parent = std::out_of_range;
using parent::parent;
};
// A slightly more elegant way to silence -Werror=unused-variable
template<typename T>
void
silence_unused(const T& val)
{
(void)val;
}
namespace stdx {
// XXX(RLB) This method takes any container in, but always puts the resuls in
// std::vector. The output could be made generic with a Rust-like syntax,
// defining a PendingTransform object that caches the inputs, with a template
// `collect()` method that puts them in an output container. Which makes the
// calling syntax as follows:
//
// auto out = stdx::transform(in, f).collect<Container>();
//
// (You always need the explicit specialization, even if assigning it to an
// explicitly typed variable, because C++ won't infer return types.)
//
// Given that the above syntax is pretty chatty, and we never need anything
// other than vectors here anyway, I have left this as-is.
template<typename Value, typename Container, typename UnaryOperation>
std::vector<Value>
transform(const Container& c, const UnaryOperation& op)
{
auto out = std::vector<Value>{};
auto ins = std::inserter(out, out.begin());
std::transform(c.begin(), c.end(), ins, op);
return out;
}
template<typename Container, typename UnaryPredicate>
bool
any_of(const Container& c, const UnaryPredicate& pred)
{
return std::any_of(c.begin(), c.end(), pred);
}
template<typename Container, typename UnaryPredicate>
bool
all_of(const Container& c, const UnaryPredicate& pred)
{
return std::all_of(c.begin(), c.end(), pred);
}
template<typename Container, typename UnaryPredicate>
auto
count_if(const Container& c, const UnaryPredicate& pred)
{
return std::count_if(c.begin(), c.end(), pred);
}
template<typename Container, typename Value>
bool
contains(const Container& c, const Value& val)
{
return std::find(c.begin(), c.end(), val) != c.end();
}
template<typename Container, typename UnaryPredicate>
auto
find_if(Container& c, const UnaryPredicate& pred)
{
return std::find_if(c.begin(), c.end(), pred);
}
template<typename Container, typename UnaryPredicate>
auto
find_if(const Container& c, const UnaryPredicate& pred)
{
return std::find_if(c.begin(), c.end(), pred);
}
template<typename Container, typename Value>
auto
upper_bound(const Container& c, const Value& val)
{
return std::upper_bound(c.begin(), c.end(), val);
}
} // namespace stdx
} // namespace mlspp

View File

@@ -0,0 +1,380 @@
#pragma once
#include "mls/credential.h"
#include "mls/crypto.h"
#include "mls/tree_math.h"
namespace mlspp {
// enum {
// reserved(0),
// mls10(1),
// (255)
// } ProtocolVersion;
enum class ProtocolVersion : uint16_t
{
mls10 = 0x01,
};
extern const std::array<ProtocolVersion, 1> all_supported_versions;
// struct {
// ExtensionType extension_type;
// opaque extension_data<V>;
// } Extension;
struct Extension
{
using Type = uint16_t;
Type type;
bytes data;
TLS_SERIALIZABLE(type, data)
};
struct ExtensionType
{
static constexpr Extension::Type application_id = 1;
static constexpr Extension::Type ratchet_tree = 2;
static constexpr Extension::Type required_capabilities = 3;
static constexpr Extension::Type external_pub = 4;
static constexpr Extension::Type external_senders = 5;
// XXX(RLB) There is no IANA-registered type for this extension yet, so we use
// a value from the vendor-specific space
static constexpr Extension::Type sframe_parameters = 0xff02;
};
struct ExtensionList
{
std::vector<Extension> extensions;
// XXX(RLB) It would be good if this maintained extensions in order. It might
// be possible to do this automatically by changing the storage to a
// map<ExtensionType, bytes> and extending the TLS code to marshal that type.
template<typename T>
inline void add(const T& obj)
{
auto data = tls::marshal(obj);
add(T::type, std::move(data));
}
void add(Extension::Type type, bytes data);
template<typename T>
std::optional<T> find() const
{
for (const auto& ext : extensions) {
if (ext.type == T::type) {
return tls::get<T>(ext.data);
}
}
return std::nullopt;
}
bool has(uint16_t type) const;
TLS_SERIALIZABLE(extensions)
};
// enum {
// reserved(0),
// key_package(1),
// update(2),
// commit(3),
// (255)
// } LeafNodeSource;
enum struct LeafNodeSource : uint8_t
{
key_package = 1,
update = 2,
commit = 3,
};
// struct {
// ProtocolVersion versions<V>;
// CipherSuite ciphersuites<V>;
// ExtensionType extensions<V>;
// ProposalType proposals<V>;
// CredentialType credentials<V>;
// } Capabilities;
struct Capabilities
{
std::vector<ProtocolVersion> versions;
std::vector<CipherSuite::ID> cipher_suites;
std::vector<Extension::Type> extensions;
std::vector<uint16_t> proposals;
std::vector<CredentialType> credentials;
static Capabilities create_default();
bool extensions_supported(const std::vector<Extension::Type>& required) const;
bool proposals_supported(const std::vector<uint16_t>& required) const;
bool credential_supported(const Credential& credential) const;
template<typename Container>
bool credentials_supported(const Container& required) const
{
return stdx::all_of(required, [&](CredentialType type) {
return stdx::contains(credentials, type);
});
}
TLS_SERIALIZABLE(versions, cipher_suites, extensions, proposals, credentials)
};
// struct {
// uint64 not_before;
// uint64 not_after;
// } Lifetime;
struct Lifetime
{
uint64_t not_before;
uint64_t not_after;
static Lifetime create_default();
TLS_SERIALIZABLE(not_before, not_after)
};
// struct {
// HPKEPublicKey encryption_key;
// SignaturePublicKey signature_key;
// Credential credential;
// Capabilities capabilities;
//
// LeafNodeSource leaf_node_source;
// select (leaf_node_source) {
// case add:
// Lifetime lifetime;
//
// case update:
// struct {}
//
// case commit:
// opaque parent_hash<V>;
// }
//
// Extension extensions<V>;
// // SignWithLabel(., "LeafNodeTBS", LeafNodeTBS)
// opaque signature<V>;
// } LeafNode;
struct Empty
{
TLS_SERIALIZABLE()
};
struct ParentHash
{
bytes parent_hash;
TLS_SERIALIZABLE(parent_hash);
};
struct LeafNodeOptions
{
std::optional<Credential> credential;
std::optional<Capabilities> capabilities;
std::optional<ExtensionList> extensions;
};
// TODO Move this to treekem.h
struct LeafNode
{
HPKEPublicKey encryption_key;
SignaturePublicKey signature_key;
Credential credential;
Capabilities capabilities;
var::variant<Lifetime, Empty, ParentHash> content;
ExtensionList extensions;
bytes signature;
LeafNode() = default;
LeafNode(const LeafNode&) = default;
LeafNode(LeafNode&&) = default;
LeafNode& operator=(const LeafNode&) = default;
LeafNode& operator=(LeafNode&&) = default;
LeafNode(CipherSuite cipher_suite,
HPKEPublicKey encryption_key_in,
SignaturePublicKey signature_key_in,
Credential credential_in,
Capabilities capabilities_in,
Lifetime lifetime_in,
ExtensionList extensions_in,
const SignaturePrivateKey& sig_priv);
LeafNode for_update(CipherSuite cipher_suite,
const bytes& group_id,
LeafIndex leaf_index,
HPKEPublicKey encryption_key,
const LeafNodeOptions& opts,
const SignaturePrivateKey& sig_priv_in) const;
LeafNode for_commit(CipherSuite cipher_suite,
const bytes& group_id,
LeafIndex leaf_index,
HPKEPublicKey encryption_key,
const bytes& parent_hash,
const LeafNodeOptions& opts,
const SignaturePrivateKey& sig_priv_in) const;
void set_capabilities(Capabilities capabilities_in);
LeafNodeSource source() const;
struct MemberBinding
{
bytes group_id;
LeafIndex leaf_index;
TLS_SERIALIZABLE(group_id, leaf_index);
};
void sign(CipherSuite cipher_suite,
const SignaturePrivateKey& sig_priv,
const std::optional<MemberBinding>& binding);
bool verify(CipherSuite cipher_suite,
const std::optional<MemberBinding>& binding) const;
bool verify_expiry(uint64_t now) const;
bool verify_extension_support(const ExtensionList& ext_list) const;
TLS_SERIALIZABLE(encryption_key,
signature_key,
credential,
capabilities,
content,
extensions,
signature)
TLS_TRAITS(tls::pass,
tls::pass,
tls::pass,
tls::pass,
tls::variant<LeafNodeSource>,
tls::pass,
tls::pass)
private:
LeafNode clone_with_options(HPKEPublicKey encryption_key,
const LeafNodeOptions& opts) const;
bytes to_be_signed(const std::optional<MemberBinding>& binding) const;
};
// Concrete extension types
struct RequiredCapabilitiesExtension
{
std::vector<Extension::Type> extensions;
std::vector<uint16_t> proposals;
static const Extension::Type type;
TLS_SERIALIZABLE(extensions, proposals)
};
struct ApplicationIDExtension
{
bytes id;
static const Extension::Type type;
TLS_SERIALIZABLE(id)
};
///
/// NodeType, ParentNode, and KeyPackage
///
// TODO move this to treekem.h
struct ParentNode
{
HPKEPublicKey public_key;
bytes parent_hash;
std::vector<LeafIndex> unmerged_leaves;
bytes hash(CipherSuite suite) const;
TLS_SERIALIZABLE(public_key, parent_hash, unmerged_leaves)
};
// TODO Move this to messages.h
// struct {
// ProtocolVersion version;
// CipherSuite cipher_suite;
// HPKEPublicKey init_key;
// LeafNode leaf_node;
// Extension extensions<V>;
// // SignWithLabel(., "KeyPackageTBS", KeyPackageTBS)
// opaque signature<V>;
// } KeyPackage;
struct KeyPackage
{
ProtocolVersion version;
CipherSuite cipher_suite;
HPKEPublicKey init_key;
LeafNode leaf_node;
ExtensionList extensions;
bytes signature;
KeyPackage();
KeyPackage(CipherSuite suite_in,
HPKEPublicKey init_key_in,
LeafNode leaf_node_in,
ExtensionList extensions_in,
const SignaturePrivateKey& sig_priv_in);
KeyPackageRef ref() const;
void sign(const SignaturePrivateKey& sig_priv);
bool verify() const;
TLS_SERIALIZABLE(version,
cipher_suite,
init_key,
leaf_node,
extensions,
signature)
private:
bytes to_be_signed() const;
};
///
/// UpdatePath
///
// struct {
// HPKEPublicKey public_key;
// HPKECiphertext encrypted_path_secret<V>;
// } UpdatePathNode;
struct UpdatePathNode
{
HPKEPublicKey public_key;
std::vector<HPKECiphertext> encrypted_path_secret;
TLS_SERIALIZABLE(public_key, encrypted_path_secret)
};
// struct {
// LeafNode leaf_node;
// UpdatePathNode nodes<V>;
// } UpdatePath;
struct UpdatePath
{
LeafNode leaf_node;
std::vector<UpdatePathNode> nodes;
TLS_SERIALIZABLE(leaf_node, nodes)
};
} // namespace mlspp
namespace mlspp::tls {
TLS_VARIANT_MAP(mlspp::LeafNodeSource,
mlspp::Lifetime,
key_package)
TLS_VARIANT_MAP(mlspp::LeafNodeSource, mlspp::Empty, update)
TLS_VARIANT_MAP(mlspp::LeafNodeSource,
mlspp::ParentHash,
commit)
} // namespace mlspp::tls

View File

@@ -0,0 +1,228 @@
#pragma once
#include <mls/common.h>
#include <mls/crypto.h>
namespace mlspp {
namespace hpke {
struct UserInfoVC;
}
// struct {
// opaque identity<0..2^16-1>;
// SignaturePublicKey public_key;
// } BasicCredential;
struct BasicCredential
{
BasicCredential() {}
BasicCredential(bytes identity_in)
: identity(std::move(identity_in))
{
}
bytes identity;
TLS_SERIALIZABLE(identity)
};
struct X509Credential
{
struct CertData
{
bytes data;
TLS_SERIALIZABLE(data)
};
X509Credential() = default;
explicit X509Credential(const std::vector<bytes>& der_chain_in);
SignatureScheme signature_scheme() const;
SignaturePublicKey public_key() const;
bool valid_for(const SignaturePublicKey& pub) const;
// TODO(rlb) This should be const or exposed via a method
std::vector<CertData> der_chain;
private:
SignaturePublicKey _public_key;
SignatureScheme _signature_scheme;
};
tls::ostream&
operator<<(tls::ostream& str, const X509Credential& obj);
tls::istream&
operator>>(tls::istream& str, X509Credential& obj);
struct UserInfoVCCredential
{
UserInfoVCCredential() = default;
explicit UserInfoVCCredential(std::string userinfo_vc_jwt_in);
std::string userinfo_vc_jwt;
bool valid_for(const SignaturePublicKey& pub) const;
bool valid_from(const PublicJWK& pub) const;
friend tls::ostream operator<<(tls::ostream& str,
const UserInfoVCCredential& obj);
friend tls::istream operator>>(tls::istream& str, UserInfoVCCredential& obj);
friend bool operator==(const UserInfoVCCredential& lhs,
const UserInfoVCCredential& rhs);
friend bool operator!=(const UserInfoVCCredential& lhs,
const UserInfoVCCredential& rhs);
private:
std::shared_ptr<hpke::UserInfoVC> _vc;
};
bool
operator==(const X509Credential& lhs, const X509Credential& rhs);
enum struct CredentialType : uint16_t
{
reserved = 0,
basic = 1,
x509 = 2,
userinfo_vc_draft_00 = 0xFE00,
multi_draft_00 = 0xFF00,
// GREASE values, included here mainly so that debugger output looks nice
GREASE_0 = 0x0A0A,
GREASE_1 = 0x1A1A,
GREASE_2 = 0x2A2A,
GREASE_3 = 0x3A3A,
GREASE_4 = 0x4A4A,
GREASE_5 = 0x5A5A,
GREASE_6 = 0x6A6A,
GREASE_7 = 0x7A7A,
GREASE_8 = 0x8A8A,
GREASE_9 = 0x9A9A,
GREASE_A = 0xAAAA,
GREASE_B = 0xBABA,
GREASE_C = 0xCACA,
GREASE_D = 0xDADA,
GREASE_E = 0xEAEA,
};
// struct {
// Credential credential;
// SignaturePublicKey credential_key;
// opaque signature<V>;
// } CredentialBinding
//
// struct {
// CredentialBinding bindings<V>;
// } MultiCredential;
struct CredentialBinding;
struct CredentialBindingInput;
struct MultiCredential
{
MultiCredential() = default;
MultiCredential(const std::vector<CredentialBindingInput>& binding_inputs,
const SignaturePublicKey& signature_key);
std::vector<CredentialBinding> bindings;
bool valid_for(const SignaturePublicKey& pub) const;
TLS_SERIALIZABLE(bindings)
};
// struct {
// CredentialType credential_type;
// select (credential_type) {
// case basic:
// BasicCredential;
//
// case x509:
// opaque cert_data<1..2^24-1>;
// };
// } Credential;
struct Credential
{
Credential() = default;
CredentialType type() const;
template<typename T>
const T& get() const
{
return var::get<T>(_cred);
}
static Credential basic(const bytes& identity);
static Credential x509(const std::vector<bytes>& der_chain);
static Credential userinfo_vc(const std::string& userinfo_vc_jwt);
static Credential multi(
const std::vector<CredentialBindingInput>& binding_inputs,
const SignaturePublicKey& signature_key);
bool valid_for(const SignaturePublicKey& pub) const;
TLS_SERIALIZABLE(_cred)
TLS_TRAITS(tls::variant<CredentialType>)
private:
using SpecificCredential = var::variant<BasicCredential,
X509Credential,
UserInfoVCCredential,
MultiCredential>;
Credential(SpecificCredential specific);
SpecificCredential _cred;
};
// XXX(RLB): This struct needs to appear below Credential so that all types are
// concrete at the appropriate points.
struct CredentialBindingInput
{
CipherSuite cipher_suite;
Credential credential;
const SignaturePrivateKey& credential_priv;
};
struct CredentialBinding
{
CipherSuite cipher_suite;
Credential credential;
SignaturePublicKey credential_key;
bytes signature;
CredentialBinding() = default;
CredentialBinding(CipherSuite suite_in,
Credential credential_in,
const SignaturePrivateKey& credential_priv,
const SignaturePublicKey& signature_key);
bool valid_for(const SignaturePublicKey& signature_key) const;
TLS_SERIALIZABLE(cipher_suite, credential, credential_key, signature)
private:
bytes to_be_signed(const SignaturePublicKey& signature_key) const;
};
} // namespace mlspp
namespace mlspp::tls {
TLS_VARIANT_MAP(mlspp::CredentialType,
mlspp::BasicCredential,
basic)
TLS_VARIANT_MAP(mlspp::CredentialType,
mlspp::X509Credential,
x509)
TLS_VARIANT_MAP(mlspp::CredentialType,
mlspp::UserInfoVCCredential,
userinfo_vc_draft_00)
TLS_VARIANT_MAP(mlspp::CredentialType,
mlspp::MultiCredential,
multi_draft_00)
} // namespace mlspp::tls

266
DPP/mlspp/include/mls/crypto.h Executable file
View File

@@ -0,0 +1,266 @@
#pragma once
#include <hpke/digest.h>
#include <hpke/hpke.h>
#include <hpke/random.h>
#include <hpke/signature.h>
#include <mls/common.h>
#include <tls/tls_syntax.h>
#include <vector>
namespace mlspp {
/// Signature Code points, borrowed from RFC 8446
enum struct SignatureScheme : uint16_t
{
ecdsa_secp256r1_sha256 = 0x0403,
ecdsa_secp384r1_sha384 = 0x0805,
ecdsa_secp521r1_sha512 = 0x0603,
ed25519 = 0x0807,
ed448 = 0x0808,
rsa_pkcs1_sha256 = 0x0401,
};
SignatureScheme
tls_signature_scheme(hpke::Signature::ID id);
/// Cipher suites
struct KeyAndNonce
{
bytes key;
bytes nonce;
};
// opaque HashReference<V>;
// HashReference KeyPackageRef;
// HashReference ProposalRef;
using HashReference = bytes;
using KeyPackageRef = HashReference;
using ProposalRef = HashReference;
struct CipherSuite
{
enum struct ID : uint16_t
{
unknown = 0x0000,
X25519_AES128GCM_SHA256_Ed25519 = 0x0001,
P256_AES128GCM_SHA256_P256 = 0x0002,
X25519_CHACHA20POLY1305_SHA256_Ed25519 = 0x0003,
X448_AES256GCM_SHA512_Ed448 = 0x0004,
P521_AES256GCM_SHA512_P521 = 0x0005,
X448_CHACHA20POLY1305_SHA512_Ed448 = 0x0006,
P384_AES256GCM_SHA384_P384 = 0x0007,
// GREASE values, included here mainly so that debugger output looks nice
GREASE_0 = 0x0A0A,
GREASE_1 = 0x1A1A,
GREASE_2 = 0x2A2A,
GREASE_3 = 0x3A3A,
GREASE_4 = 0x4A4A,
GREASE_5 = 0x5A5A,
GREASE_6 = 0x6A6A,
GREASE_7 = 0x7A7A,
GREASE_8 = 0x8A8A,
GREASE_9 = 0x9A9A,
GREASE_A = 0xAAAA,
GREASE_B = 0xBABA,
GREASE_C = 0xCACA,
GREASE_D = 0xDADA,
GREASE_E = 0xEAEA,
};
CipherSuite();
CipherSuite(ID id_in);
ID cipher_suite() const { return id; }
SignatureScheme signature_scheme() const;
size_t secret_size() const { return get().digest.hash_size; }
size_t key_size() const { return get().hpke.aead.key_size; }
size_t nonce_size() const { return get().hpke.aead.nonce_size; }
bytes zero() const { return bytes(secret_size(), 0); }
const hpke::HPKE& hpke() const { return get().hpke; }
const hpke::Digest& digest() const { return get().digest; }
const hpke::Signature& sig() const { return get().sig; }
bytes expand_with_label(const bytes& secret,
const std::string& label,
const bytes& context,
size_t length) const;
bytes derive_secret(const bytes& secret, const std::string& label) const;
bytes derive_tree_secret(const bytes& secret,
const std::string& label,
uint32_t generation,
size_t length) const;
template<typename T>
bytes ref(const T& value) const
{
return raw_ref(reference_label<T>(), tls::marshal(value));
}
bytes raw_ref(const bytes& label, const bytes& value) const
{
// RefHash(label, value) = Hash(RefHashInput)
//
// struct {
// opaque label<V>;
// opaque value<V>;
// } RefHashInput;
auto w = tls::ostream();
w << label << value;
return digest().hash(w.bytes());
}
TLS_SERIALIZABLE(id)
private:
ID id;
struct Ciphers
{
hpke::HPKE hpke;
const hpke::Digest& digest;
const hpke::Signature& sig;
};
const Ciphers& get() const;
template<typename T>
static const bytes& reference_label();
};
#if WITH_BORINGSSL
extern const std::array<CipherSuite::ID, 5> all_supported_suites;
#else
extern const std::array<CipherSuite::ID, 7> all_supported_suites;
#endif
// Utilities
using mlspp::hpke::random_bytes;
// HPKE Keys
namespace encrypt_label {
extern const std::string update_path_node;
extern const std::string welcome;
} // namespace encrypt_label
struct HPKECiphertext
{
bytes kem_output;
bytes ciphertext;
TLS_SERIALIZABLE(kem_output, ciphertext)
};
struct HPKEPublicKey
{
bytes data;
HPKECiphertext encrypt(CipherSuite suite,
const std::string& label,
const bytes& context,
const bytes& pt) const;
std::tuple<bytes, bytes> do_export(CipherSuite suite,
const bytes& info,
const std::string& label,
size_t size) const;
TLS_SERIALIZABLE(data)
};
struct HPKEPrivateKey
{
static HPKEPrivateKey generate(CipherSuite suite);
static HPKEPrivateKey parse(CipherSuite suite, const bytes& data);
static HPKEPrivateKey derive(CipherSuite suite, const bytes& secret);
HPKEPrivateKey() = default;
bytes data;
HPKEPublicKey public_key;
bytes decrypt(CipherSuite suite,
const std::string& label,
const bytes& context,
const HPKECiphertext& ct) const;
bytes do_export(CipherSuite suite,
const bytes& info,
const bytes& kem_output,
const std::string& label,
size_t size) const;
void set_public_key(CipherSuite suite);
TLS_SERIALIZABLE(data)
private:
HPKEPrivateKey(bytes priv_data, bytes pub_data);
};
// Signature Keys
namespace sign_label {
extern const std::string mls_content;
extern const std::string leaf_node;
extern const std::string key_package;
extern const std::string group_info;
extern const std::string multi_credential;
} // namespace sign_label
struct SignaturePublicKey
{
static SignaturePublicKey from_jwk(CipherSuite suite,
const std::string& json_str);
bytes data;
bool verify(const CipherSuite& suite,
const std::string& label,
const bytes& message,
const bytes& signature) const;
std::string to_jwk(CipherSuite suite) const;
TLS_SERIALIZABLE(data)
};
struct PublicJWK
{
SignatureScheme signature_scheme;
std::optional<std::string> key_id;
SignaturePublicKey public_key;
static PublicJWK parse(const std::string& jwk_json);
};
struct SignaturePrivateKey
{
static SignaturePrivateKey generate(CipherSuite suite);
static SignaturePrivateKey parse(CipherSuite suite, const bytes& data);
static SignaturePrivateKey derive(CipherSuite suite, const bytes& secret);
static SignaturePrivateKey from_jwk(CipherSuite suite,
const std::string& json_str);
SignaturePrivateKey() = default;
bytes data;
SignaturePublicKey public_key;
bytes sign(const CipherSuite& suite,
const std::string& label,
const bytes& message) const;
void set_public_key(CipherSuite suite);
std::string to_jwk(CipherSuite suite) const;
TLS_SERIALIZABLE(data)
private:
SignaturePrivateKey(bytes priv_data, bytes pub_data);
};
} // namespace mlspp

View File

@@ -0,0 +1,205 @@
#pragma once
#include <map>
#include <mls/common.h>
#include <mls/crypto.h>
#include <mls/messages.h>
#include <mls/tree_math.h>
namespace mlspp {
struct HashRatchet
{
CipherSuite suite;
bytes next_secret;
uint32_t next_generation;
std::map<uint32_t, KeyAndNonce> cache;
size_t key_size;
size_t nonce_size;
size_t secret_size;
// These defaults are necessary for use with containers
HashRatchet() = default;
HashRatchet(const HashRatchet& other) = default;
HashRatchet(HashRatchet&& other) = default;
HashRatchet& operator=(const HashRatchet& other) = default;
HashRatchet& operator=(HashRatchet&& other) = default;
HashRatchet(CipherSuite suite_in, bytes base_secret_in);
std::tuple<uint32_t, KeyAndNonce> next();
KeyAndNonce get(uint32_t generation);
void erase(uint32_t generation);
};
struct SecretTree
{
SecretTree() = default;
SecretTree(CipherSuite suite_in,
LeafCount group_size_in,
bytes encryption_secret_in);
bool has_leaf(LeafIndex sender) { return sender < group_size; }
bytes get(LeafIndex sender);
private:
CipherSuite suite;
LeafCount group_size;
NodeIndex root;
std::map<NodeIndex, bytes> secrets;
size_t secret_size;
};
using ReuseGuard = std::array<uint8_t, 4>;
struct GroupKeySource
{
enum struct RatchetType
{
handshake,
application,
};
GroupKeySource() = default;
GroupKeySource(CipherSuite suite_in,
LeafCount group_size,
bytes encryption_secret);
bool has_leaf(LeafIndex sender) { return secret_tree.has_leaf(sender); }
std::tuple<uint32_t, ReuseGuard, KeyAndNonce> next(ContentType content_type,
LeafIndex sender);
KeyAndNonce get(ContentType content_type,
LeafIndex sender,
uint32_t generation,
ReuseGuard reuse_guard);
void erase(ContentType type, LeafIndex sender, uint32_t generation);
private:
CipherSuite suite;
SecretTree secret_tree;
using Key = std::tuple<RatchetType, LeafIndex>;
std::map<Key, HashRatchet> chains;
HashRatchet& chain(RatchetType type, LeafIndex sender);
HashRatchet& chain(ContentType type, LeafIndex sender);
static const std::array<RatchetType, 2> all_ratchet_types;
};
struct KeyScheduleEpoch
{
private:
CipherSuite suite;
public:
bytes joiner_secret;
bytes epoch_secret;
bytes sender_data_secret;
bytes encryption_secret;
bytes exporter_secret;
bytes epoch_authenticator;
bytes external_secret;
bytes confirmation_key;
bytes membership_key;
bytes resumption_psk;
bytes init_secret;
HPKEPrivateKey external_priv;
KeyScheduleEpoch() = default;
// Full initializer, used by invited joiner
static KeyScheduleEpoch joiner(CipherSuite suite_in,
const bytes& joiner_secret,
const std::vector<PSKWithSecret>& psks,
const bytes& context);
// Ciphersuite-only initializer, used by external joiner
KeyScheduleEpoch(CipherSuite suite_in);
// Initial epoch
KeyScheduleEpoch(CipherSuite suite_in,
const bytes& init_secret,
const bytes& context);
static std::tuple<bytes, bytes> external_init(
CipherSuite suite,
const HPKEPublicKey& external_pub);
bytes receive_external_init(const bytes& kem_output) const;
KeyScheduleEpoch next(const bytes& commit_secret,
const std::vector<PSKWithSecret>& psks,
const std::optional<bytes>& force_init_secret,
const bytes& context) const;
GroupKeySource encryption_keys(LeafCount size) const;
bytes confirmation_tag(const bytes& confirmed_transcript_hash) const;
bytes do_export(const std::string& label,
const bytes& context,
size_t size) const;
PSKWithSecret resumption_psk_w_secret(ResumptionPSKUsage usage,
const bytes& group_id,
epoch_t epoch);
static bytes make_psk_secret(CipherSuite suite,
const std::vector<PSKWithSecret>& psks);
static bytes welcome_secret(CipherSuite suite,
const bytes& joiner_secret,
const std::vector<PSKWithSecret>& psks);
static KeyAndNonce sender_data_keys(CipherSuite suite,
const bytes& sender_data_secret,
const bytes& ciphertext);
// TODO(RLB) make these methods private, but accessible to test vectors
KeyScheduleEpoch(CipherSuite suite_in,
const bytes& init_secret,
const bytes& commit_secret,
const bytes& psk_secret,
const bytes& context);
KeyScheduleEpoch next_raw(const bytes& commit_secret,
const bytes& psk_secret,
const std::optional<bytes>& force_init_secret,
const bytes& context) const;
static bytes welcome_secret_raw(CipherSuite suite,
const bytes& joiner_secret,
const bytes& psk_secret);
private:
KeyScheduleEpoch(CipherSuite suite_in,
const bytes& joiner_secret,
const bytes& psk_secret,
const bytes& context);
};
bool
operator==(const KeyScheduleEpoch& lhs, const KeyScheduleEpoch& rhs);
struct TranscriptHash
{
CipherSuite suite;
bytes confirmed;
bytes interim;
// For a new group
TranscriptHash(CipherSuite suite_in);
// For joining a group
TranscriptHash(CipherSuite suite_in,
bytes confirmed_in,
const bytes& confirmation_tag);
void update(const AuthenticatedContent& content_auth);
void update_confirmed(const AuthenticatedContent& content_auth);
void update_interim(const bytes& confirmation_tag);
void update_interim(const AuthenticatedContent& content_auth);
};
bool
operator==(const TranscriptHash& lhs, const TranscriptHash& rhs);
} // namespace mlspp

752
DPP/mlspp/include/mls/messages.h Executable file
View File

@@ -0,0 +1,752 @@
#pragma once
#include "mls/common.h"
#include "mls/core_types.h"
#include "mls/credential.h"
#include "mls/crypto.h"
#include "mls/treekem.h"
#include <optional>
#include <tls/tls_syntax.h>
namespace mlspp {
struct ExternalPubExtension
{
HPKEPublicKey external_pub;
static const uint16_t type;
TLS_SERIALIZABLE(external_pub)
};
struct RatchetTreeExtension
{
TreeKEMPublicKey tree;
static const uint16_t type;
TLS_SERIALIZABLE(tree)
};
struct ExternalSender
{
SignaturePublicKey signature_key;
Credential credential;
TLS_SERIALIZABLE(signature_key, credential);
};
struct ExternalSendersExtension
{
std::vector<ExternalSender> senders;
static const uint16_t type;
TLS_SERIALIZABLE(senders);
};
struct SFrameParameters
{
uint16_t cipher_suite;
uint8_t epoch_bits;
static const uint16_t type;
TLS_SERIALIZABLE(cipher_suite, epoch_bits)
};
struct SFrameCapabilities
{
std::vector<uint16_t> cipher_suites;
bool compatible(const SFrameParameters& params) const;
static const uint16_t type;
TLS_SERIALIZABLE(cipher_suites)
};
///
/// PSKs
///
enum struct PSKType : uint8_t
{
reserved = 0,
external = 1,
resumption = 2,
};
struct ExternalPSK
{
bytes psk_id;
TLS_SERIALIZABLE(psk_id)
};
enum struct ResumptionPSKUsage : uint8_t
{
reserved = 0,
application = 1,
reinit = 2,
branch = 3,
};
struct ResumptionPSK
{
ResumptionPSKUsage usage;
bytes psk_group_id;
epoch_t psk_epoch;
TLS_SERIALIZABLE(usage, psk_group_id, psk_epoch)
};
struct PreSharedKeyID
{
var::variant<ExternalPSK, ResumptionPSK> content;
bytes psk_nonce;
TLS_SERIALIZABLE(content, psk_nonce)
TLS_TRAITS(tls::variant<PSKType>, tls::pass)
};
struct PreSharedKeys
{
std::vector<PreSharedKeyID> psks;
TLS_SERIALIZABLE(psks)
};
struct PSKWithSecret
{
PreSharedKeyID id;
bytes secret;
};
// struct {
// ProtocolVersion version = mls10;
// CipherSuite cipher_suite;
// opaque group_id<V>;
// uint64 epoch;
// opaque tree_hash<V>;
// opaque confirmed_transcript_hash<V>;
// Extension extensions<V>;
// } GroupContext;
struct GroupContext
{
ProtocolVersion version{ ProtocolVersion::mls10 };
CipherSuite cipher_suite;
bytes group_id;
epoch_t epoch;
bytes tree_hash;
bytes confirmed_transcript_hash;
ExtensionList extensions;
GroupContext() = default;
GroupContext(CipherSuite cipher_suite_in,
bytes group_id_in,
epoch_t epoch_in,
bytes tree_hash_in,
bytes confirmed_transcript_hash_in,
ExtensionList extensions_in);
TLS_SERIALIZABLE(version,
cipher_suite,
group_id,
epoch,
tree_hash,
confirmed_transcript_hash,
extensions)
};
// struct {
// GroupContext group_context;
// Extension extensions<V>;
// MAC confirmation_tag;
// uint32 signer;
// // SignWithLabel(., "GroupInfoTBS", GroupInfoTBS)
// opaque signature<V>;
// } GroupInfo;
struct GroupInfo
{
GroupContext group_context;
ExtensionList extensions;
bytes confirmation_tag;
LeafIndex signer;
bytes signature;
GroupInfo() = default;
GroupInfo(GroupContext group_context_in,
ExtensionList extensions_in,
bytes confirmation_tag_in);
bytes to_be_signed() const;
void sign(const TreeKEMPublicKey& tree,
LeafIndex signer_index,
const SignaturePrivateKey& priv);
bool verify(const TreeKEMPublicKey& tree) const;
// These methods exist only to simplify unit testing
void sign(LeafIndex signer_index, const SignaturePrivateKey& priv);
bool verify(const SignaturePublicKey& pub) const;
TLS_SERIALIZABLE(group_context,
extensions,
confirmation_tag,
signer,
signature)
};
// struct {
// opaque joiner_secret<1..255>;
// optional<PathSecret> path_secret;
// PreSharedKeys psks;
// } GroupSecrets;
struct GroupSecrets
{
struct PathSecret
{
bytes secret;
TLS_SERIALIZABLE(secret)
};
bytes joiner_secret;
std::optional<PathSecret> path_secret;
PreSharedKeys psks;
TLS_SERIALIZABLE(joiner_secret, path_secret, psks)
};
// struct {
// opaque key_package_hash<1..255>;
// HPKECiphertext encrypted_group_secrets;
// } EncryptedGroupSecrets;
struct EncryptedGroupSecrets
{
KeyPackageRef new_member;
HPKECiphertext encrypted_group_secrets;
TLS_SERIALIZABLE(new_member, encrypted_group_secrets)
};
// struct {
// ProtocolVersion version = mls10;
// CipherSuite cipher_suite;
// EncryptedGroupSecrets group_secretss<1..2^32-1>;
// opaque encrypted_group_info<1..2^32-1>;
// } Welcome;
struct Welcome
{
CipherSuite cipher_suite;
std::vector<EncryptedGroupSecrets> secrets;
bytes encrypted_group_info;
Welcome();
Welcome(CipherSuite suite,
const bytes& joiner_secret,
const std::vector<PSKWithSecret>& psks,
const GroupInfo& group_info);
void encrypt(const KeyPackage& kp, const std::optional<bytes>& path_secret);
std::optional<int> find(const KeyPackage& kp) const;
GroupSecrets decrypt_secrets(int kp_index,
const HPKEPrivateKey& init_priv) const;
GroupInfo decrypt(const bytes& joiner_secret,
const std::vector<PSKWithSecret>& psks) const;
TLS_SERIALIZABLE(cipher_suite, secrets, encrypted_group_info)
private:
bytes _joiner_secret;
PreSharedKeys _psks;
static KeyAndNonce group_info_key_nonce(
CipherSuite suite,
const bytes& joiner_secret,
const std::vector<PSKWithSecret>& psks);
};
///
/// Proposals & Commit
///
// Add
struct Add
{
KeyPackage key_package;
TLS_SERIALIZABLE(key_package)
};
// Update
struct Update
{
LeafNode leaf_node;
TLS_SERIALIZABLE(leaf_node)
};
// Remove
struct Remove
{
LeafIndex removed;
TLS_SERIALIZABLE(removed)
};
// PreSharedKey
struct PreSharedKey
{
PreSharedKeyID psk;
TLS_SERIALIZABLE(psk)
};
// ReInit
struct ReInit
{
bytes group_id;
ProtocolVersion version;
CipherSuite cipher_suite;
ExtensionList extensions;
TLS_SERIALIZABLE(group_id, version, cipher_suite, extensions)
};
// ExternalInit
struct ExternalInit
{
bytes kem_output;
TLS_SERIALIZABLE(kem_output)
};
// GroupContextExtensions
struct GroupContextExtensions
{
ExtensionList group_context_extensions;
TLS_SERIALIZABLE(group_context_extensions)
};
struct ProposalType;
struct Proposal
{
using Type = uint16_t;
var::variant<Add,
Update,
Remove,
PreSharedKey,
ReInit,
ExternalInit,
GroupContextExtensions>
content;
Type proposal_type() const;
TLS_SERIALIZABLE(content)
TLS_TRAITS(tls::variant<ProposalType>)
};
struct ProposalType
{
static constexpr Proposal::Type invalid = 0;
static constexpr Proposal::Type add = 1;
static constexpr Proposal::Type update = 2;
static constexpr Proposal::Type remove = 3;
static constexpr Proposal::Type psk = 4;
static constexpr Proposal::Type reinit = 5;
static constexpr Proposal::Type external_init = 6;
static constexpr Proposal::Type group_context_extensions = 7;
constexpr ProposalType()
: val(invalid)
{
}
constexpr ProposalType(Proposal::Type pt)
: val(pt)
{
}
Proposal::Type val;
TLS_SERIALIZABLE(val)
};
enum struct ProposalOrRefType : uint8_t
{
reserved = 0,
value = 1,
reference = 2,
};
struct ProposalOrRef
{
var::variant<Proposal, ProposalRef> content;
TLS_SERIALIZABLE(content)
TLS_TRAITS(tls::variant<ProposalOrRefType>)
};
// struct {
// ProposalOrRef proposals<0..2^32-1>;
// optional<UpdatePath> path;
// } Commit;
struct Commit
{
std::vector<ProposalOrRef> proposals;
std::optional<UpdatePath> path;
// Validate that the commit is acceptable as an external commit, and if so,
// produce the public key from the ExternalInit proposal
std::optional<bytes> valid_external() const;
TLS_SERIALIZABLE(proposals, path)
};
// struct {
// opaque group_id<0..255>;
// uint32 epoch;
// uint32 sender;
// ContentType content_type;
//
// select (PublicMessage.content_type) {
// case handshake:
// GroupOperation operation;
// opaque confirmation<0..255>;
//
// case application:
// opaque application_data<0..2^32-1>;
// }
//
// opaque signature<0..2^16-1>;
// } PublicMessage;
struct ApplicationData
{
bytes data;
TLS_SERIALIZABLE(data)
};
struct GroupContext;
enum struct WireFormat : uint16_t
{
reserved = 0,
mls_public_message = 1,
mls_private_message = 2,
mls_welcome = 3,
mls_group_info = 4,
mls_key_package = 5,
};
enum struct ContentType : uint8_t
{
invalid = 0,
application = 1,
proposal = 2,
commit = 3,
};
enum struct SenderType : uint8_t
{
invalid = 0,
member = 1,
external = 2,
new_member_proposal = 3,
new_member_commit = 4,
};
struct MemberSender
{
LeafIndex sender;
TLS_SERIALIZABLE(sender);
};
struct ExternalSenderIndex
{
uint32_t sender_index;
TLS_SERIALIZABLE(sender_index)
};
struct NewMemberProposalSender
{
TLS_SERIALIZABLE()
};
struct NewMemberCommitSender
{
TLS_SERIALIZABLE()
};
struct Sender
{
var::variant<MemberSender,
ExternalSenderIndex,
NewMemberProposalSender,
NewMemberCommitSender>
sender;
SenderType sender_type() const;
TLS_SERIALIZABLE(sender)
TLS_TRAITS(tls::variant<SenderType>)
};
///
/// MLSMessage and friends
///
struct GroupKeySource;
struct GroupContent
{
using RawContent = var::variant<ApplicationData, Proposal, Commit>;
bytes group_id;
epoch_t epoch;
Sender sender;
bytes authenticated_data;
RawContent content;
GroupContent() = default;
GroupContent(bytes group_id_in,
epoch_t epoch_in,
Sender sender_in,
bytes authenticated_data_in,
RawContent content_in);
GroupContent(bytes group_id_in,
epoch_t epoch_in,
Sender sender_in,
bytes authenticated_data_in,
ContentType content_type);
ContentType content_type() const;
TLS_SERIALIZABLE(group_id, epoch, sender, authenticated_data, content)
TLS_TRAITS(tls::pass,
tls::pass,
tls::pass,
tls::pass,
tls::variant<ContentType>)
};
struct GroupContentAuthData
{
ContentType content_type = ContentType::invalid;
bytes signature;
std::optional<bytes> confirmation_tag;
friend tls::ostream& operator<<(tls::ostream& str,
const GroupContentAuthData& obj);
friend tls::istream& operator>>(tls::istream& str, GroupContentAuthData& obj);
friend bool operator==(const GroupContentAuthData& lhs,
const GroupContentAuthData& rhs);
};
struct AuthenticatedContent
{
WireFormat wire_format;
GroupContent content;
GroupContentAuthData auth;
AuthenticatedContent() = default;
static AuthenticatedContent sign(WireFormat wire_format,
GroupContent content,
CipherSuite suite,
const SignaturePrivateKey& sig_priv,
const std::optional<GroupContext>& context);
bool verify(CipherSuite suite,
const SignaturePublicKey& sig_pub,
const std::optional<GroupContext>& context) const;
bytes confirmed_transcript_hash_input() const;
bytes interim_transcript_hash_input() const;
void set_confirmation_tag(const bytes& confirmation_tag);
bool check_confirmation_tag(const bytes& confirmation_tag) const;
friend tls::ostream& operator<<(tls::ostream& str,
const AuthenticatedContent& obj);
friend tls::istream& operator>>(tls::istream& str, AuthenticatedContent& obj);
friend bool operator==(const AuthenticatedContent& lhs,
const AuthenticatedContent& rhs);
private:
AuthenticatedContent(WireFormat wire_format_in, GroupContent content_in);
AuthenticatedContent(WireFormat wire_format_in,
GroupContent content_in,
GroupContentAuthData auth_in);
bytes to_be_signed(const std::optional<GroupContext>& context) const;
friend struct PublicMessage;
friend struct PrivateMessage;
};
struct ValidatedContent
{
const AuthenticatedContent& authenticated_content() const;
friend bool operator==(const ValidatedContent& lhs,
const ValidatedContent& rhs);
private:
AuthenticatedContent content_auth;
ValidatedContent(AuthenticatedContent content_auth_in);
friend struct PublicMessage;
friend struct PrivateMessage;
friend class State;
};
struct PublicMessage
{
PublicMessage() = default;
bytes get_group_id() const { return content.group_id; }
epoch_t get_epoch() const { return content.epoch; }
static PublicMessage protect(AuthenticatedContent content_auth,
CipherSuite suite,
const std::optional<bytes>& membership_key,
const std::optional<GroupContext>& context);
std::optional<ValidatedContent> unprotect(
CipherSuite suite,
const std::optional<bytes>& membership_key,
const std::optional<GroupContext>& context) const;
bool contains(const AuthenticatedContent& content_auth) const;
// TODO(RLB) Make this private and expose only to tests
AuthenticatedContent authenticated_content() const;
friend tls::ostream& operator<<(tls::ostream& str, const PublicMessage& obj);
friend tls::istream& operator>>(tls::istream& str, PublicMessage& obj);
friend bool operator==(const PublicMessage& lhs, const PublicMessage& rhs);
friend bool operator!=(const PublicMessage& lhs, const PublicMessage& rhs);
private:
GroupContent content;
GroupContentAuthData auth;
std::optional<bytes> membership_tag;
PublicMessage(AuthenticatedContent content_auth);
bytes membership_mac(CipherSuite suite,
const bytes& membership_key,
const std::optional<GroupContext>& context) const;
};
struct PrivateMessage
{
PrivateMessage() = default;
bytes get_group_id() const { return group_id; }
epoch_t get_epoch() const { return epoch; }
static PrivateMessage protect(AuthenticatedContent content_auth,
CipherSuite suite,
GroupKeySource& keys,
const bytes& sender_data_secret,
size_t padding_size);
std::optional<ValidatedContent> unprotect(
CipherSuite suite,
GroupKeySource& keys,
const bytes& sender_data_secret) const;
TLS_SERIALIZABLE(group_id,
epoch,
content_type,
authenticated_data,
encrypted_sender_data,
ciphertext)
private:
bytes group_id;
epoch_t epoch;
ContentType content_type;
bytes authenticated_data;
bytes encrypted_sender_data;
bytes ciphertext;
PrivateMessage(GroupContent content,
bytes encrypted_sender_data_in,
bytes ciphertext_in);
};
struct MLSMessage
{
ProtocolVersion version = ProtocolVersion::mls10;
var::variant<PublicMessage, PrivateMessage, Welcome, GroupInfo, KeyPackage>
message;
bytes group_id() const;
epoch_t epoch() const;
WireFormat wire_format() const;
MLSMessage() = default;
MLSMessage(PublicMessage public_message);
MLSMessage(PrivateMessage private_message);
MLSMessage(Welcome welcome);
MLSMessage(GroupInfo group_info);
MLSMessage(KeyPackage key_package);
TLS_SERIALIZABLE(version, message)
TLS_TRAITS(tls::pass, tls::variant<WireFormat>)
};
MLSMessage
external_proposal(CipherSuite suite,
const bytes& group_id,
epoch_t epoch,
const Proposal& proposal,
uint32_t signer_index,
const SignaturePrivateKey& sig_priv);
} // namespace mlspp
namespace mlspp::tls {
TLS_VARIANT_MAP(mlspp::PSKType, mlspp::ExternalPSK, external)
TLS_VARIANT_MAP(mlspp::PSKType,
mlspp::ResumptionPSK,
resumption)
TLS_VARIANT_MAP(mlspp::ProposalOrRefType,
mlspp::Proposal,
value)
TLS_VARIANT_MAP(mlspp::ProposalOrRefType,
mlspp::ProposalRef,
reference)
TLS_VARIANT_MAP(mlspp::ProposalType, mlspp::Add, add)
TLS_VARIANT_MAP(mlspp::ProposalType, mlspp::Update, update)
TLS_VARIANT_MAP(mlspp::ProposalType, mlspp::Remove, remove)
TLS_VARIANT_MAP(mlspp::ProposalType, mlspp::PreSharedKey, psk)
TLS_VARIANT_MAP(mlspp::ProposalType, mlspp::ReInit, reinit)
TLS_VARIANT_MAP(mlspp::ProposalType,
mlspp::ExternalInit,
external_init)
TLS_VARIANT_MAP(mlspp::ProposalType,
mlspp::GroupContextExtensions,
group_context_extensions)
TLS_VARIANT_MAP(mlspp::ContentType,
mlspp::ApplicationData,
application)
TLS_VARIANT_MAP(mlspp::ContentType, mlspp::Proposal, proposal)
TLS_VARIANT_MAP(mlspp::ContentType, mlspp::Commit, commit)
TLS_VARIANT_MAP(mlspp::SenderType, mlspp::MemberSender, member)
TLS_VARIANT_MAP(mlspp::SenderType,
mlspp::ExternalSenderIndex,
external)
TLS_VARIANT_MAP(mlspp::SenderType,
mlspp::NewMemberProposalSender,
new_member_proposal)
TLS_VARIANT_MAP(mlspp::SenderType,
mlspp::NewMemberCommitSender,
new_member_commit)
TLS_VARIANT_MAP(mlspp::WireFormat,
mlspp::PublicMessage,
mls_public_message)
TLS_VARIANT_MAP(mlspp::WireFormat,
mlspp::PrivateMessage,
mls_private_message)
TLS_VARIANT_MAP(mlspp::WireFormat, mlspp::Welcome, mls_welcome)
TLS_VARIANT_MAP(mlspp::WireFormat,
mlspp::GroupInfo,
mls_group_info)
TLS_VARIANT_MAP(mlspp::WireFormat,
mlspp::KeyPackage,
mls_key_package)
} // namespace mlspp::tls

98
DPP/mlspp/include/mls/session.h Executable file
View File

@@ -0,0 +1,98 @@
#pragma once
#include <mls/common.h>
#include <mls/core_types.h>
#include <mls/credential.h>
#include <mls/crypto.h>
#include <mls/state.h>
namespace mlspp {
class PendingJoin;
class Session;
class Client
{
public:
Client(CipherSuite suite_in,
SignaturePrivateKey sig_priv_in,
Credential cred_in);
Session begin_session(const bytes& group_id) const;
PendingJoin start_join() const;
private:
const CipherSuite suite;
const SignaturePrivateKey sig_priv;
const Credential cred;
};
class PendingJoin
{
public:
PendingJoin(PendingJoin&& other) noexcept;
PendingJoin& operator=(PendingJoin&& other) noexcept;
~PendingJoin();
bytes key_package() const;
Session complete(const bytes& welcome) const;
private:
struct Inner;
std::unique_ptr<Inner> inner;
PendingJoin(Inner* inner);
friend class Client;
};
class Session
{
public:
Session(Session&& other) noexcept;
Session& operator=(Session&& other) noexcept;
~Session();
// Settings
void encrypt_handshake(bool enabled);
// Message producers
bytes add(const bytes& key_package_data);
bytes update();
bytes remove(uint32_t index);
std::tuple<bytes, bytes> commit(const bytes& proposal);
std::tuple<bytes, bytes> commit(const std::vector<bytes>& proposals);
std::tuple<bytes, bytes> commit();
// Message consumers
bool handle(const bytes& handshake_data);
// Information about the current state
epoch_t epoch() const;
LeafIndex index() const;
CipherSuite cipher_suite() const;
const ExtensionList& extensions() const;
const TreeKEMPublicKey& tree() const;
bytes do_export(const std::string& label,
const bytes& context,
size_t size) const;
GroupInfo group_info() const;
std::vector<LeafNode> roster() const;
bytes epoch_authenticator() const;
// Application message protection
bytes protect(const bytes& plaintext);
bytes unprotect(const bytes& ciphertext);
protected:
struct Inner;
std::unique_ptr<Inner> inner;
Session(Inner* inner);
friend class Client;
friend class PendingJoin;
friend bool operator==(const Session& lhs, const Session& rhs);
friend bool operator!=(const Session& lhs, const Session& rhs);
};
} // namespace mlspp

431
DPP/mlspp/include/mls/state.h Executable file
View File

@@ -0,0 +1,431 @@
#pragma once
#include "mls/crypto.h"
#include "mls/key_schedule.h"
#include "mls/messages.h"
#include "mls/treekem.h"
#include <list>
#include <optional>
#include <vector>
namespace mlspp {
// Index into the session roster
struct RosterIndex : public UInt32
{
using UInt32::UInt32;
};
struct CommitOpts
{
std::vector<Proposal> extra_proposals;
bool inline_tree;
bool force_path;
LeafNodeOptions leaf_node_opts;
};
struct MessageOpts
{
bool encrypt = false;
bytes authenticated_data;
size_t padding_size = 0;
};
class State
{
public:
///
/// Constructors
///
// Initialize an empty group
State(bytes group_id,
CipherSuite suite,
HPKEPrivateKey enc_priv,
SignaturePrivateKey sig_priv,
const LeafNode& leaf_node,
ExtensionList extensions);
// Initialize a group from a Welcome
State(const HPKEPrivateKey& init_priv,
HPKEPrivateKey leaf_priv,
SignaturePrivateKey sig_priv,
const KeyPackage& key_package,
const Welcome& welcome,
const std::optional<TreeKEMPublicKey>& tree,
std::map<bytes, bytes> psks);
// Join a group from outside
// XXX(RLB) To be fully general, we would need a few more options here, e.g.,
// whether to include PSKs or evict our prior appearance.
static std::tuple<MLSMessage, State> external_join(
const bytes& leaf_secret,
SignaturePrivateKey sig_priv,
const KeyPackage& key_package,
const GroupInfo& group_info,
const std::optional<TreeKEMPublicKey>& tree,
const MessageOpts& msg_opts,
std::optional<LeafIndex> remove_prior,
const std::map<bytes, bytes>& psks);
// Propose that a new member be added a group
static MLSMessage new_member_add(const bytes& group_id,
epoch_t epoch,
const KeyPackage& new_member,
const SignaturePrivateKey& sig_priv);
///
/// Message factories
///
Proposal add_proposal(const KeyPackage& key_package) const;
Proposal update_proposal(HPKEPrivateKey leaf_priv,
const LeafNodeOptions& opts);
Proposal remove_proposal(RosterIndex index) const;
Proposal remove_proposal(LeafIndex removed) const;
Proposal group_context_extensions_proposal(ExtensionList exts) const;
Proposal pre_shared_key_proposal(const bytes& external_psk_id) const;
Proposal pre_shared_key_proposal(const bytes& group_id, epoch_t epoch) const;
static Proposal reinit_proposal(bytes group_id,
ProtocolVersion version,
CipherSuite cipher_suite,
ExtensionList extensions);
MLSMessage add(const KeyPackage& key_package, const MessageOpts& msg_opts);
MLSMessage update(HPKEPrivateKey leaf_priv,
const LeafNodeOptions& opts,
const MessageOpts& msg_opts);
MLSMessage remove(RosterIndex index, const MessageOpts& msg_opts);
MLSMessage remove(LeafIndex removed, const MessageOpts& msg_opts);
MLSMessage group_context_extensions(ExtensionList exts,
const MessageOpts& msg_opts);
MLSMessage pre_shared_key(const bytes& external_psk_id,
const MessageOpts& msg_opts);
MLSMessage pre_shared_key(const bytes& group_id,
epoch_t epoch,
const MessageOpts& msg_opts);
MLSMessage reinit(bytes group_id,
ProtocolVersion version,
CipherSuite cipher_suite,
ExtensionList extensions,
const MessageOpts& msg_opts);
std::tuple<MLSMessage, Welcome, State> commit(
const bytes& leaf_secret,
const std::optional<CommitOpts>& opts,
const MessageOpts& msg_opts);
///
/// Generic handshake message handlers
///
std::optional<State> handle(const MLSMessage& msg);
std::optional<State> handle(const MLSMessage& msg,
std::optional<State> cached_state);
std::optional<State> handle(const ValidatedContent& content_auth);
std::optional<State> handle(const ValidatedContent& content_auth,
std::optional<State> cached_state);
///
/// PSK management
///
void add_resumption_psk(const bytes& group_id, epoch_t epoch, bytes secret);
void remove_resumption_psk(const bytes& group_id, epoch_t epoch);
void add_external_psk(const bytes& id, const bytes& secret);
void remove_external_psk(const bytes& id);
///
/// Accessors
///
const bytes& group_id() const { return _group_id; }
epoch_t epoch() const { return _epoch; }
LeafIndex index() const { return _index; }
CipherSuite cipher_suite() const { return _suite; }
const ExtensionList& extensions() const { return _extensions; }
const TreeKEMPublicKey& tree() const { return _tree; }
const bytes& resumption_psk() const { return _key_schedule.resumption_psk; }
bytes do_export(const std::string& label,
const bytes& context,
size_t size) const;
GroupInfo group_info(bool inline_tree) const;
// Ordered list of credentials from non-blank leaves
std::vector<LeafNode> roster() const;
bytes epoch_authenticator() const;
///
/// Unwrap messages so that applications can inspect them
///
ValidatedContent unwrap(const MLSMessage& msg);
///
/// Application encryption and decryption
///
MLSMessage protect(const bytes& authenticated_data,
const bytes& pt,
size_t padding_size);
std::tuple<bytes, bytes> unprotect(const MLSMessage& ct);
// Assemble a group context for this state
GroupContext group_context() const;
// Subgroup branching
std::tuple<State, Welcome> create_branch(
bytes group_id,
HPKEPrivateKey enc_priv,
SignaturePrivateKey sig_priv,
const LeafNode& leaf_node,
ExtensionList extensions,
const std::vector<KeyPackage>& key_packages,
const bytes& leaf_secret,
const CommitOpts& commit_opts) const;
State handle_branch(const HPKEPrivateKey& init_priv,
HPKEPrivateKey enc_priv,
SignaturePrivateKey sig_priv,
const KeyPackage& key_package,
const Welcome& welcome,
const std::optional<TreeKEMPublicKey>& tree) const;
// Reinitialization
struct Tombstone
{
std::tuple<State, Welcome> create_welcome(
HPKEPrivateKey enc_priv,
SignaturePrivateKey sig_priv,
const LeafNode& leaf_node,
const std::vector<KeyPackage>& key_packages,
const bytes& leaf_secret,
const CommitOpts& commit_opts) const;
State handle_welcome(const HPKEPrivateKey& init_priv,
HPKEPrivateKey enc_priv,
SignaturePrivateKey sig_priv,
const KeyPackage& key_package,
const Welcome& welcome,
const std::optional<TreeKEMPublicKey>& tree) const;
TLS_SERIALIZABLE(prior_group_id, prior_epoch, resumption_psk, reinit);
const bytes epoch_authenticator;
const ReInit reinit;
private:
Tombstone(const State& state_in, ReInit reinit_in);
bytes prior_group_id;
epoch_t prior_epoch;
bytes resumption_psk;
friend class State;
};
std::tuple<Tombstone, MLSMessage> reinit_commit(
const bytes& leaf_secret,
const std::optional<CommitOpts>& opts,
const MessageOpts& msg_opts);
Tombstone handle_reinit_commit(const MLSMessage& commit);
protected:
// Shared confirmed state
// XXX(rlb@ipv.sx): Can these be made const?
CipherSuite _suite;
bytes _group_id;
epoch_t _epoch;
TreeKEMPublicKey _tree;
TreeKEMPrivateKey _tree_priv;
TranscriptHash _transcript_hash;
ExtensionList _extensions;
// Shared secret state
KeyScheduleEpoch _key_schedule;
GroupKeySource _keys;
// Per-participant state
LeafIndex _index;
SignaturePrivateKey _identity_priv;
// Storage for PSKs
std::map<bytes, bytes> _external_psks;
using EpochRef = std::tuple<bytes, epoch_t>;
std::map<EpochRef, bytes> _resumption_psks;
// Cache of Proposals and update secrets
struct CachedProposal
{
ProposalRef ref;
Proposal proposal;
std::optional<LeafIndex> sender;
};
std::list<CachedProposal> _pending_proposals;
struct CachedUpdate
{
HPKEPrivateKey update_priv;
Update proposal;
};
std::optional<CachedUpdate> _cached_update;
// Assemble a preliminary, unjoined group state
State(SignaturePrivateKey sig_priv,
const GroupInfo& group_info,
const std::optional<TreeKEMPublicKey>& tree);
// Assemble a group from a Welcome, allowing for resumption PSKs
State(const HPKEPrivateKey& init_priv,
HPKEPrivateKey leaf_priv,
SignaturePrivateKey sig_priv,
const KeyPackage& key_package,
const Welcome& welcome,
const std::optional<TreeKEMPublicKey>& tree,
std::map<bytes, bytes> external_psks,
std::map<EpochRef, bytes> resumption_psks);
// Import a tree from an externally-provided tree or an extension
TreeKEMPublicKey import_tree(const bytes& tree_hash,
const std::optional<TreeKEMPublicKey>& external,
const ExtensionList& extensions);
bool validate_tree() const;
// Form a commit, covering all the cases with slightly different validation
// rules:
// * Normal
// * External
// * Branch
// * Reinit
struct NormalCommitParams
{};
struct ExternalCommitParams
{
KeyPackage joiner_key_package;
bytes force_init_secret;
};
struct RestartCommitParams
{
ResumptionPSKUsage allowed_usage;
};
struct ReInitCommitParams
{};
using CommitParams = var::variant<NormalCommitParams,
ExternalCommitParams,
RestartCommitParams,
ReInitCommitParams>;
std::tuple<MLSMessage, Welcome, State> commit(
const bytes& leaf_secret,
const std::optional<CommitOpts>& opts,
const MessageOpts& msg_opts,
CommitParams params);
std::optional<State> handle(
const MLSMessage& msg,
std::optional<State> cached_state,
const std::optional<CommitParams>& expected_params);
std::optional<State> handle(
const ValidatedContent& val_content,
std::optional<State> cached_state,
const std::optional<CommitParams>& expected_params);
// Create an MLSMessage encapsulating some content
template<typename Inner>
AuthenticatedContent sign(const Sender& sender,
Inner&& content,
const bytes& authenticated_data,
bool encrypt) const;
MLSMessage protect(AuthenticatedContent&& content_auth, size_t padding_size);
template<typename Inner>
MLSMessage protect_full(Inner&& content, const MessageOpts& msg_opts);
// Apply the changes requested by various messages
LeafIndex apply(const Add& add);
void apply(LeafIndex target, const Update& update);
void apply(LeafIndex target,
const Update& update,
const HPKEPrivateKey& leaf_priv);
LeafIndex apply(const Remove& remove);
void apply(const GroupContextExtensions& gce);
std::vector<LeafIndex> apply(const std::vector<CachedProposal>& proposals,
Proposal::Type required_type);
std::tuple<std::vector<LeafIndex>, std::vector<PSKWithSecret>> apply(
const std::vector<CachedProposal>& proposals);
// Verify that a specific key package or all members support a given set of
// extensions
bool extensions_supported(const ExtensionList& exts) const;
// Extract proposals and PSKs from cache
void cache_proposal(AuthenticatedContent content_auth);
std::optional<CachedProposal> resolve(
const ProposalOrRef& id,
std::optional<LeafIndex> sender_index) const;
std::vector<CachedProposal> must_resolve(
const std::vector<ProposalOrRef>& ids,
std::optional<LeafIndex> sender_index) const;
std::vector<PSKWithSecret> resolve(
const std::vector<PreSharedKeyID>& psks) const;
// Check properties of proposals
bool valid(const LeafNode& leaf_node,
LeafNodeSource required_source,
std::optional<LeafIndex> index) const;
bool valid(const KeyPackage& key_package) const;
bool valid(const Add& add) const;
bool valid(LeafIndex sender, const Update& update) const;
bool valid(const Remove& remove) const;
bool valid(const PreSharedKey& psk) const;
static bool valid(const ReInit& reinit);
bool valid(const ExternalInit& external_init) const;
bool valid(const GroupContextExtensions& gce) const;
bool valid(std::optional<LeafIndex> sender, const Proposal& proposal) const;
bool valid(const std::vector<CachedProposal>& proposals,
LeafIndex commit_sender,
const CommitParams& params) const;
bool valid_normal(const std::vector<CachedProposal>& proposals,
LeafIndex commit_sender) const;
bool valid_external(const std::vector<CachedProposal>& proposals) const;
static bool valid_reinit(const std::vector<CachedProposal>& proposals);
static bool valid_restart(const std::vector<CachedProposal>& proposals,
ResumptionPSKUsage allowed_usage);
static bool valid_external_proposal_type(const Proposal::Type proposal_type);
CommitParams infer_commit_type(
const std::optional<LeafIndex>& sender,
const std::vector<CachedProposal>& proposals,
const std::optional<CommitParams>& expected_params) const;
static bool path_required(const std::vector<CachedProposal>& proposals);
// Compare the **shared** attributes of the states
friend bool operator==(const State& lhs, const State& rhs);
friend bool operator!=(const State& lhs, const State& rhs);
// Derive and set the secrets for an epoch, given some new entropy
void update_epoch_secrets(const bytes& commit_secret,
const std::vector<PSKWithSecret>& psks,
const std::optional<bytes>& force_init_secret);
// Signature verification over a handshake message
bool verify_internal(const AuthenticatedContent& content_auth) const;
bool verify_external(const AuthenticatedContent& content_auth) const;
bool verify_new_member_proposal(
const AuthenticatedContent& content_auth) const;
bool verify_new_member_commit(const AuthenticatedContent& content_auth) const;
bool verify(const AuthenticatedContent& content_auth) const;
// Convert a Roster entry into LeafIndex
LeafIndex leaf_for_roster_entry(RosterIndex index) const;
// Create a draft successor state
State successor() const;
};
} // namespace mlspp

107
DPP/mlspp/include/mls/tree_math.h Executable file
View File

@@ -0,0 +1,107 @@
#pragma once
#include <cstdint>
#include <tls/tls_syntax.h>
#include <vector>
// The below functions provide the index calculus for the tree
// structures used in MLS. They are premised on a "flat"
// representation of a balanced binary tree. Leaf nodes are
// even-numbered nodes, with the n-th leaf at 2*n. Intermediate
// nodes are held in odd-numbered nodes. For example, a 11-element
// tree has the following structure:
//
// X
// X
// X X X
// X X X X X
// X X X X X X X X X X X
// 0 1 2 3 4 5 6 7 8 9 a b c d e f 10 11 12 13 14
//
// This allows us to compute relationships between tree nodes simply
// by manipulating indices, rather than having to maintain
// complicated structures in memory, even for partial trees. (The
// storage for a tree can just be a map[int]Node dictionary or an
// array.) The basic rule is that the high-order bits of parent and
// child nodes have the following relation:
//
// 01x = <00x, 10x>
namespace mlspp {
// Index types go in the overall namespace
// XXX(rlb@ipv.sx): Seems like this stuff can probably get
// simplified down a fair bit.
struct UInt32
{
uint32_t val;
UInt32()
: val(0)
{
}
explicit UInt32(uint32_t val_in)
: val(val_in)
{
}
TLS_SERIALIZABLE(val)
};
struct NodeCount;
struct LeafCount : public UInt32
{
using UInt32::UInt32;
explicit LeafCount(const NodeCount w);
static LeafCount full(const LeafCount n);
};
struct NodeCount : public UInt32
{
using UInt32::UInt32;
explicit NodeCount(const LeafCount n);
};
struct NodeIndex;
struct LeafIndex : public UInt32
{
using UInt32::UInt32;
explicit LeafIndex(const NodeIndex x);
bool operator<(const LeafIndex other) const { return val < other.val; }
bool operator<(const LeafCount other) const { return val < other.val; }
NodeIndex ancestor(LeafIndex other) const;
};
struct NodeIndex : public UInt32
{
using UInt32::UInt32;
explicit NodeIndex(const LeafIndex x);
bool operator<(const NodeIndex other) const { return val < other.val; }
bool operator<(const NodeCount other) const { return val < other.val; }
static NodeIndex root(LeafCount n);
bool is_leaf() const;
bool is_below(NodeIndex other) const;
NodeIndex left() const;
NodeIndex right() const;
NodeIndex parent() const;
NodeIndex sibling() const;
// Returns the sibling of this node "relative to this ancestor" -- the child
// of `ancestor` that is not in the direct path of this node.
NodeIndex sibling(NodeIndex ancestor) const;
std::vector<NodeIndex> dirpath(LeafCount n);
std::vector<NodeIndex> copath(LeafCount n);
uint32_t level() const;
};
} // namespace mlspp

255
DPP/mlspp/include/mls/treekem.h Executable file
View File

@@ -0,0 +1,255 @@
#pragma once
#include "mls/common.h"
#include "mls/core_types.h"
#include "mls/crypto.h"
#include "mls/tree_math.h"
#include <tls/tls_syntax.h>
#define ENABLE_TREE_DUMP 1
namespace mlspp {
enum struct NodeType : uint8_t
{
reserved = 0x00,
leaf = 0x01,
parent = 0x02,
};
struct Node
{
var::variant<LeafNode, ParentNode> node;
const HPKEPublicKey& public_key() const;
std::optional<bytes> parent_hash() const;
TLS_SERIALIZABLE(node)
TLS_TRAITS(tls::variant<NodeType>)
};
struct OptionalNode
{
std::optional<Node> node;
bool blank() const { return !node.has_value(); }
bool leaf() const
{
return !blank() && var::holds_alternative<LeafNode>(opt::get(node).node);
}
LeafNode& leaf_node() { return var::get<LeafNode>(opt::get(node).node); }
const LeafNode& leaf_node() const
{
return var::get<LeafNode>(opt::get(node).node);
}
ParentNode& parent_node()
{
return var::get<ParentNode>(opt::get(node).node);
}
const ParentNode& parent_node() const
{
return var::get<ParentNode>(opt::get(node).node);
}
TLS_SERIALIZABLE(node)
};
struct TreeKEMPublicKey;
struct TreeKEMPrivateKey
{
CipherSuite suite;
LeafIndex index;
bytes update_secret;
std::map<NodeIndex, bytes> path_secrets;
std::map<NodeIndex, HPKEPrivateKey> private_key_cache;
static TreeKEMPrivateKey solo(CipherSuite suite,
LeafIndex index,
HPKEPrivateKey leaf_priv);
static TreeKEMPrivateKey create(const TreeKEMPublicKey& pub,
LeafIndex from,
const bytes& leaf_secret);
static TreeKEMPrivateKey joiner(const TreeKEMPublicKey& pub,
LeafIndex index,
HPKEPrivateKey leaf_priv,
NodeIndex intersect,
const std::optional<bytes>& path_secret);
void set_leaf_priv(HPKEPrivateKey priv);
std::tuple<NodeIndex, bytes, bool> shared_path_secret(LeafIndex to) const;
bool have_private_key(NodeIndex n) const;
std::optional<HPKEPrivateKey> private_key(NodeIndex n);
std::optional<HPKEPrivateKey> private_key(NodeIndex n) const;
void decap(LeafIndex from,
const TreeKEMPublicKey& pub,
const bytes& context,
const UpdatePath& path,
const std::vector<LeafIndex>& except);
void truncate(LeafCount size);
bool consistent(const TreeKEMPrivateKey& other) const;
bool consistent(const TreeKEMPublicKey& other) const;
#if ENABLE_TREE_DUMP
void dump() const;
#endif
// TODO(RLB) Make this private but exposed to test vectors
void implant(const TreeKEMPublicKey& pub,
NodeIndex start,
const bytes& path_secret);
};
struct TreeKEMPublicKey
{
CipherSuite suite;
LeafCount size{ 0 };
std::vector<OptionalNode> nodes;
explicit TreeKEMPublicKey(CipherSuite suite);
TreeKEMPublicKey() = default;
TreeKEMPublicKey(const TreeKEMPublicKey& other) = default;
TreeKEMPublicKey(TreeKEMPublicKey&& other) = default;
TreeKEMPublicKey& operator=(const TreeKEMPublicKey& other) = default;
TreeKEMPublicKey& operator=(TreeKEMPublicKey&& other) = default;
LeafIndex allocate_leaf();
LeafIndex add_leaf(const LeafNode& leaf);
void update_leaf(LeafIndex index, const LeafNode& leaf);
void blank_path(LeafIndex index);
TreeKEMPrivateKey update(LeafIndex from,
const bytes& leaf_secret,
const bytes& group_id,
const SignaturePrivateKey& sig_priv,
const LeafNodeOptions& opts);
UpdatePath encap(const TreeKEMPrivateKey& priv,
const bytes& context,
const std::vector<LeafIndex>& except) const;
void merge(LeafIndex from, const UpdatePath& path);
void set_hash_all();
const bytes& get_hash(NodeIndex index);
bytes root_hash() const;
bool parent_hash_valid(LeafIndex from, const UpdatePath& path) const;
bool parent_hash_valid() const;
bool has_leaf(LeafIndex index) const;
std::optional<LeafIndex> find(const LeafNode& leaf) const;
std::optional<LeafNode> leaf_node(LeafIndex index) const;
std::vector<NodeIndex> resolve(NodeIndex index) const;
template<typename UnaryPredicate>
bool all_leaves(const UnaryPredicate& pred) const
{
for (LeafIndex i{ 0 }; i < size; i.val++) {
const auto& node = node_at(i);
if (node.blank()) {
continue;
}
if (!pred(i, node.leaf_node())) {
return false;
}
}
return true;
}
template<typename UnaryPredicate>
bool any_leaf(const UnaryPredicate& pred) const
{
for (LeafIndex i{ 0 }; i < size; i.val++) {
const auto& node = node_at(i);
if (node.blank()) {
continue;
}
if (pred(i, node.leaf_node())) {
return true;
}
}
return false;
}
using FilteredDirectPath =
std::vector<std::tuple<NodeIndex, std::vector<NodeIndex>>>;
FilteredDirectPath filtered_direct_path(NodeIndex index) const;
void truncate();
OptionalNode& node_at(NodeIndex n);
const OptionalNode& node_at(NodeIndex n) const;
OptionalNode& node_at(LeafIndex n);
const OptionalNode& node_at(LeafIndex n) const;
TLS_SERIALIZABLE(nodes)
#if ENABLE_TREE_DUMP
void dump() const;
#endif
private:
std::map<NodeIndex, bytes> hashes;
void clear_hash_all();
void clear_hash_path(LeafIndex index);
bool has_parent_hash(NodeIndex child, const bytes& target_ph) const;
bytes parent_hash(const ParentNode& parent, NodeIndex copath_child) const;
std::vector<bytes> parent_hashes(
LeafIndex from,
const FilteredDirectPath& fdp,
const std::vector<UpdatePathNode>& path_nodes) const;
using TreeHashCache = std::map<NodeIndex, std::pair<size_t, bytes>>;
const bytes& original_tree_hash(TreeHashCache& cache,
NodeIndex index,
std::vector<LeafIndex> parent_except) const;
bytes original_parent_hash(TreeHashCache& cache,
NodeIndex parent,
NodeIndex sibling) const;
bool exists_in_tree(const HPKEPublicKey& key,
std::optional<LeafIndex> except) const;
bool exists_in_tree(const SignaturePublicKey& key,
std::optional<LeafIndex> except) const;
OptionalNode blank_node;
friend struct TreeKEMPrivateKey;
};
tls::ostream&
operator<<(tls::ostream& str, const TreeKEMPublicKey& obj);
tls::istream&
operator>>(tls::istream& str, TreeKEMPublicKey& obj);
struct LeafNodeHashInput;
struct ParentNodeHashInput;
} // namespace mlspp
namespace mlspp::tls {
TLS_VARIANT_MAP(mlspp::NodeType, mlspp::LeafNodeHashInput, leaf)
TLS_VARIANT_MAP(mlspp::NodeType,
mlspp::ParentNodeHashInput,
parent)
TLS_VARIANT_MAP(mlspp::NodeType, mlspp::LeafNode, leaf)
TLS_VARIANT_MAP(mlspp::NodeType, mlspp::ParentNode, parent)
} // namespace mlspp::tls

4
DPP/mlspp/include/namespace.h Executable file
View File

@@ -0,0 +1,4 @@
#pragma once
// Configurable top-level MLS namespace
#define MLS_NAMESPACE ../include/dpp/mlspp/mls

5
DPP/mlspp/include/version.h Executable file
View File

@@ -0,0 +1,5 @@
#pragma once
/* Global version strings */
extern const char VERSION[];
extern const char HASHVAR[];

4
DPP/mlspp/lib/CMakeLists.txt Executable file
View File

@@ -0,0 +1,4 @@
add_subdirectory(bytes)
add_subdirectory(hpke)
add_subdirectory(tls_syntax)
add_subdirectory(mls_vectors)

View File

@@ -0,0 +1,25 @@
set(CURRENT_LIB_NAME bytes)
###
### Library Config
###
file(GLOB_RECURSE LIB_HEADERS CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/include/*.h")
file(GLOB_RECURSE LIB_SOURCES CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp")
add_library(${CURRENT_LIB_NAME} STATIC ${LIB_HEADERS} ${LIB_SOURCES})
add_dependencies(${CURRENT_LIB_NAME} tls_syntax)
include_directories("${PROJECT_SOURCE_DIR}/../bytes/include")
target_link_libraries(${CURRENT_LIB_NAME} tls_syntax)
target_include_directories(${CURRENT_LIB_NAME}
PUBLIC
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/include>
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
$<INSTALL_INTERFACE:include/${PROJECT_NAME}>
)
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/bytes/include")
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/hpke/include")
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/mls_vectors/include")
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/tls_syntax/include")

View File

@@ -0,0 +1,127 @@
#pragma once
#include <string>
#include <tls/tls_syntax.h>
#include <vector>
namespace mlspp::bytes_ns {
struct bytes
{
// Ensure defaults
bytes() = default;
bytes(const bytes&) = default;
bytes& operator=(const bytes&) = default;
bytes(bytes&&) = default;
bytes& operator=(bytes&&) = default;
// Zeroize on drop
~bytes()
{
auto ptr = static_cast<volatile uint8_t*>(_data.data());
std::fill(ptr, ptr + _data.size(), uint8_t(0));
}
// Mimic std::vector ctors
bytes(size_t count, const uint8_t& value = 0)
: _data(count, value)
{
}
bytes(std::initializer_list<uint8_t> init)
: _data(init)
{
}
template<size_t N>
bytes(const std::array<uint8_t, N>& data)
: _data(data.begin(), data.end())
{
}
// Slice out sub-vectors (to avoid an iterator ctor)
bytes slice(size_t begin_index, size_t end_index) const
{
const auto begin_it = _data.begin() + begin_index;
const auto end_it = _data.begin() + end_index;
return std::vector<uint8_t>(begin_it, end_it);
}
// Freely convert to/from std::vector
bytes(const std::vector<uint8_t>& vec)
: _data(vec)
{
}
bytes(std::vector<uint8_t>&& vec)
: _data(vec)
{
}
operator const std::vector<uint8_t>&() const { return _data; }
operator std::vector<uint8_t>&() { return _data; }
operator std::vector<uint8_t>&&() && { return std::move(_data); }
const std::vector<uint8_t>& as_vec() const { return _data; }
std::vector<uint8_t>& as_vec() { return _data; }
// Pass through methods
auto data() const { return _data.data(); }
auto data() { return _data.data(); }
auto size() const { return _data.size(); }
auto empty() const { return _data.empty(); }
auto begin() const { return _data.begin(); }
auto begin() { return _data.begin(); }
auto end() const { return _data.end(); }
auto end() { return _data.end(); }
const auto& at(size_t pos) const { return _data.at(pos); }
auto& at(size_t pos) { return _data.at(pos); }
void resize(size_t count) { _data.resize(count); }
void reserve(size_t len) { _data.reserve(len); }
void push_back(uint8_t byte) { _data.push_back(byte); }
// Equality operators
bool operator==(const bytes& other) const;
bool operator!=(const bytes& other) const;
bool operator==(const std::vector<uint8_t>& other) const;
bool operator!=(const std::vector<uint8_t>& other) const;
// Arithmetic operators
bytes& operator+=(const bytes& other);
bytes operator+(const bytes& rhs) const;
bytes operator^(const bytes& rhs) const;
// Sorting operators (to allow usage as map keys)
bool operator<(const bytes& rhs) const;
// Other, external operators
friend std::ostream& operator<<(std::ostream& out, const bytes& data);
friend bool operator==(const std::vector<uint8_t>& lhs, const bytes& rhs);
friend bool operator!=(const std::vector<uint8_t>& lhs, const bytes& rhs);
// TLS syntax serialization
TLS_SERIALIZABLE(_data);
private:
std::vector<uint8_t> _data;
};
std::string
to_ascii(const bytes& data);
bytes
from_ascii(const std::string& ascii);
std::string
to_hex(const bytes& data);
bytes
from_hex(const std::string& hex);
} // namespace mlspp::bytes_ns

View File

@@ -0,0 +1,61 @@
#pragma once
#include <optional>
#include <stdexcept>
#ifdef VARIANT_COMPAT
#include <variant.hpp>
#else
#include <variant>
#endif // VARIANT_COMPAT
namespace mlspp::tls {
namespace var = std;
// In a similar vein, we provide our own safe accessors for std::optional, since
// std::optional::value() is not available on macOS 10.11.
namespace opt {
template<typename T>
T&
get(std::optional<T>& opt)
{
if (!opt) {
throw std::runtime_error("bad_optional_access");
}
return *opt;
}
template<typename T>
const T&
get(const std::optional<T>& opt)
{
if (!opt) {
throw std::runtime_error("bad_optional_access");
}
return *opt;
}
template<typename T>
T&&
get(std::optional<T>&& opt)
{
if (!opt) {
throw std::runtime_error("bad_optional_access");
}
return std::move(*opt);
}
template<typename T>
const T&&
get(const std::optional<T>&& opt)
{
if (!opt) {
throw std::runtime_error("bad_optional_access");
}
return std::move(*opt);
}
} // namespace opt
} // namespace mlspp::tls

View File

@@ -0,0 +1,569 @@
#pragma once
#include <algorithm>
#include <array>
#include <cstdint>
#include <limits>
#include <map>
#include <optional>
#include <stdexcept>
#include <vector>
#include <tls/compat.h>
namespace mlspp::tls {
// For indicating no min or max in vector definitions
const size_t none = std::numeric_limits<size_t>::max();
class WriteError : public std::invalid_argument
{
public:
using parent = std::invalid_argument;
using parent::parent;
};
class ReadError : public std::invalid_argument
{
public:
using parent = std::invalid_argument;
using parent::parent;
};
///
/// Declarations of Streams and Traits
///
class ostream
{
public:
static const size_t none = std::numeric_limits<size_t>::max();
void write_raw(const std::vector<uint8_t>& bytes);
const std::vector<uint8_t>& bytes() const { return _buffer; }
size_t size() const { return _buffer.size(); }
bool empty() const { return _buffer.empty(); }
private:
std::vector<uint8_t> _buffer;
ostream& write_uint(uint64_t value, int length);
friend ostream& operator<<(ostream& out, bool data);
friend ostream& operator<<(ostream& out, uint8_t data);
friend ostream& operator<<(ostream& out, uint16_t data);
friend ostream& operator<<(ostream& out, uint32_t data);
friend ostream& operator<<(ostream& out, uint64_t data);
template<typename T>
friend ostream& operator<<(ostream& out, const std::vector<T>& data);
friend struct varint;
};
class istream
{
public:
istream(const std::vector<uint8_t>& data)
: _buffer(data)
{
// So that we can use the constant-time pop_back
std::reverse(_buffer.begin(), _buffer.end());
}
size_t size() const { return _buffer.size(); }
bool empty() const { return _buffer.empty(); }
std::vector<uint8_t> bytes()
{
auto bytes = _buffer;
std::reverse(bytes.begin(), bytes.end());
return bytes;
}
private:
istream() {}
std::vector<uint8_t> _buffer;
uint8_t next();
template<typename T>
istream& read_uint(T& data, size_t length)
{
uint64_t value = 0;
for (size_t i = 0; i < length; i += 1) {
value = (value << unsigned(8)) + next();
}
data = static_cast<T>(value);
return *this;
}
friend istream& operator>>(istream& in, bool& data);
friend istream& operator>>(istream& in, uint8_t& data);
friend istream& operator>>(istream& in, uint16_t& data);
friend istream& operator>>(istream& in, uint32_t& data);
friend istream& operator>>(istream& in, uint64_t& data);
template<typename T>
friend istream& operator>>(istream& in, std::vector<T>& data);
friend struct varint;
};
// Traits must have static encode and decode methods, of the following form:
//
// static ostream& encode(ostream& str, const T& val);
// static istream& decode(istream& str, T& val);
//
// Trait types will never be constructed; only these static methods are used.
// The value arguments to encode and decode can be as strict or as loose as
// desired.
//
// Ultimately, all interesting encoding should be done through traits.
//
// * vectors
// * variants
// * varints
struct pass
{
template<typename T>
static ostream& encode(ostream& str, const T& val);
template<typename T>
static istream& decode(istream& str, T& val);
};
template<typename Ts>
struct variant
{
template<typename... Tp>
static inline Ts type(const var::variant<Tp...>& data);
template<typename... Tp>
static ostream& encode(ostream& str, const var::variant<Tp...>& data);
template<size_t I = 0, typename Te, typename... Tp>
static inline typename std::enable_if<I == sizeof...(Tp), void>::type
read_variant(istream&, Te, var::variant<Tp...>&);
template<size_t I = 0, typename Te, typename... Tp>
static inline typename std::enable_if <
I<sizeof...(Tp), void>::type read_variant(istream& str,
Te target_type,
var::variant<Tp...>& v);
template<typename... Tp>
static istream& decode(istream& str, var::variant<Tp...>& data);
};
struct varint
{
static ostream& encode(ostream& str, const uint64_t& val);
static istream& decode(istream& str, uint64_t& val);
};
///
/// Writer implementations
///
// Primitive writers defined in .cpp file
// Array writer
template<typename T, size_t N>
ostream&
operator<<(ostream& out, const std::array<T, N>& data)
{
for (const auto& item : data) {
out << item;
}
return out;
}
// Optional writer
template<typename T>
ostream&
operator<<(ostream& out, const std::optional<T>& opt)
{
if (!opt) {
return out << uint8_t(0);
}
return out << uint8_t(1) << opt::get(opt);
}
// Enum writer
template<typename T, std::enable_if_t<std::is_enum<T>::value, int> = 0>
ostream&
operator<<(ostream& str, const T& val)
{
auto u = static_cast<std::underlying_type_t<T>>(val);
return str << u;
}
// Vector writer
template<typename T>
ostream&
operator<<(ostream& str, const std::vector<T>& vec)
{
// Pre-encode contents
ostream temp;
for (const auto& item : vec) {
temp << item;
}
// Write the encoded length, then the pre-encoded data
varint::encode(str, temp._buffer.size());
str.write_raw(temp.bytes());
return str;
}
///
/// Reader implementations
///
// Primitive type readers defined in .cpp file
// Array reader
template<typename T, size_t N>
istream&
operator>>(istream& in, std::array<T, N>& data)
{
for (auto& item : data) {
in >> item;
}
return in;
}
// Optional reader
template<typename T>
istream&
operator>>(istream& in, std::optional<T>& opt)
{
uint8_t present = 0;
in >> present;
switch (present) {
case 0:
opt.reset();
return in;
case 1:
opt.emplace();
return in >> opt::get(opt);
default:
throw std::invalid_argument("Malformed optional");
}
}
// Enum reader
// XXX(rlb): It would be nice if this could enforce that the values are valid,
// but C++ doesn't seem to have that ability. When used as a tag for variants,
// the variant reader will enforce, at least.
template<typename T, std::enable_if_t<std::is_enum<T>::value, int> = 0>
istream&
operator>>(istream& str, T& val)
{
std::underlying_type_t<T> u;
str >> u;
val = static_cast<T>(u);
return str;
}
// Vector reader
template<typename T>
istream&
operator>>(istream& str, std::vector<T>& vec)
{
// Read the encoded data size
auto size = uint64_t(0);
varint::decode(str, size);
if (size > str._buffer.size()) {
throw ReadError("Vector is longer than remaining data");
}
// Read the elements of the vector
// NB: Remember that we store the vector in reverse order
// NB: This requires that T be default-constructible
istream r;
r._buffer =
std::vector<uint8_t>{ str._buffer.end() - size, str._buffer.end() };
vec.clear();
while (r._buffer.size() > 0) {
vec.emplace_back();
r >> vec.back();
}
// Truncate the primary buffer
str._buffer.erase(str._buffer.end() - size, str._buffer.end());
return str;
}
// Abbreviations
template<typename T>
std::vector<uint8_t>
marshal(const T& value)
{
ostream w;
w << value;
return w.bytes();
}
template<typename T>
void
unmarshal(const std::vector<uint8_t>& data, T& value)
{
istream r(data);
r >> value;
}
template<typename T, typename... Tp>
T
get(const std::vector<uint8_t>& data, Tp... args)
{
T value(args...);
unmarshal(data, value);
return value;
}
// Use this macro to define struct serialization with minimal boilerplate
#define TLS_SERIALIZABLE(...) \
static const bool _tls_serializable = true; \
auto _tls_fields_r() \
{ \
return std::forward_as_tuple(__VA_ARGS__); \
} \
auto _tls_fields_w() const \
{ \
return std::forward_as_tuple(__VA_ARGS__); \
}
// If your struct contains nontrivial members (e.g., vectors), use this to
// define traits for them.
#define TLS_TRAITS(...) \
static const bool _tls_has_traits = true; \
using _tls_traits = std::tuple<__VA_ARGS__>;
template<typename T>
struct is_serializable
{
template<typename U>
static std::true_type test(decltype(U::_tls_serializable));
template<typename U>
static std::false_type test(...);
static const bool value = decltype(test<T>(true))::value;
};
template<typename T>
struct has_traits
{
template<typename U>
static std::true_type test(decltype(U::_tls_has_traits));
template<typename U>
static std::false_type test(...);
static const bool value = decltype(test<T>(true))::value;
};
///
/// Trait implementations
///
// Pass-through (normal encoding/decoding)
template<typename T>
ostream&
pass::encode(ostream& str, const T& val)
{
return str << val;
}
template<typename T>
istream&
pass::decode(istream& str, T& val)
{
return str >> val;
}
// Variant encoding
template<typename Ts, typename Tv>
constexpr Ts
variant_map();
#define TLS_VARIANT_MAP(EnumType, MappedType, enum_value) \
template<> \
constexpr EnumType variant_map<EnumType, MappedType>() \
{ \
return EnumType::enum_value; \
}
template<typename Ts>
template<typename... Tp>
inline Ts
variant<Ts>::type(const var::variant<Tp...>& data)
{
const auto get_type = [](const auto& v) {
return variant_map<Ts, std::decay_t<decltype(v)>>();
};
return var::visit(get_type, data);
}
template<typename Ts>
template<typename... Tp>
ostream&
variant<Ts>::encode(ostream& str, const var::variant<Tp...>& data)
{
const auto write_variant = [&str](auto&& v) {
using Tv = std::decay_t<decltype(v)>;
str << variant_map<Ts, Tv>() << v;
};
var::visit(write_variant, data);
return str;
}
template<typename Ts>
template<size_t I, typename Te, typename... Tp>
inline typename std::enable_if<I == sizeof...(Tp), void>::type
variant<Ts>::read_variant(istream&, Te, var::variant<Tp...>&)
{
throw ReadError("Invalid variant type label");
}
template<typename Ts>
template<size_t I, typename Te, typename... Tp>
inline
typename std::enable_if < I<sizeof...(Tp), void>::type
variant<Ts>::read_variant(istream& str,
Te target_type,
var::variant<Tp...>& v)
{
using Tc = var::variant_alternative_t<I, var::variant<Tp...>>;
if (variant_map<Ts, Tc>() == target_type) {
str >> v.template emplace<I>();
return;
}
read_variant<I + 1>(str, target_type, v);
}
template<typename Ts>
template<typename... Tp>
istream&
variant<Ts>::decode(istream& str, var::variant<Tp...>& data)
{
Ts target_type;
str >> target_type;
read_variant(str, target_type, data);
return str;
}
// Struct writer without traits (enabled by macro)
template<size_t I = 0, typename... Tp>
inline typename std::enable_if<I == sizeof...(Tp), void>::type
write_tuple(ostream&, const std::tuple<Tp...>&)
{
}
template<size_t I = 0, typename... Tp>
inline typename std::enable_if <
I<sizeof...(Tp), void>::type
write_tuple(ostream& str, const std::tuple<Tp...>& t)
{
str << std::get<I>(t);
write_tuple<I + 1, Tp...>(str, t);
}
template<typename T>
inline
typename std::enable_if<is_serializable<T>::value && !has_traits<T>::value,
ostream&>::type
operator<<(ostream& str, const T& obj)
{
write_tuple(str, obj._tls_fields_w());
return str;
}
// Struct writer with traits (enabled by macro)
template<typename Tr, size_t I = 0, typename... Tp>
inline typename std::enable_if<I == sizeof...(Tp), void>::type
write_tuple_traits(ostream&, const std::tuple<Tp...>&)
{
}
template<typename Tr, size_t I = 0, typename... Tp>
inline typename std::enable_if <
I<sizeof...(Tp), void>::type
write_tuple_traits(ostream& str, const std::tuple<Tp...>& t)
{
std::tuple_element_t<I, Tr>::encode(str, std::get<I>(t));
write_tuple_traits<Tr, I + 1, Tp...>(str, t);
}
template<typename T>
inline
typename std::enable_if<is_serializable<T>::value && has_traits<T>::value,
ostream&>::type
operator<<(ostream& str, const T& obj)
{
write_tuple_traits<typename T::_tls_traits>(str, obj._tls_fields_w());
return str;
}
// Struct reader without traits (enabled by macro)
template<size_t I = 0, typename... Tp>
inline typename std::enable_if<I == sizeof...(Tp), void>::type
read_tuple(istream&, const std::tuple<Tp...>&)
{
}
template<size_t I = 0, typename... Tp>
inline
typename std::enable_if < I<sizeof...(Tp), void>::type
read_tuple(istream& str, const std::tuple<Tp...>& t)
{
str >> std::get<I>(t);
read_tuple<I + 1, Tp...>(str, t);
}
template<typename T>
inline
typename std::enable_if<is_serializable<T>::value && !has_traits<T>::value,
istream&>::type
operator>>(istream& str, T& obj)
{
read_tuple(str, obj._tls_fields_r());
return str;
}
// Struct reader with traits (enabled by macro)
template<typename Tr, size_t I = 0, typename... Tp>
inline typename std::enable_if<I == sizeof...(Tp), void>::type
read_tuple_traits(istream&, const std::tuple<Tp...>&)
{
}
template<typename Tr, size_t I = 0, typename... Tp>
inline typename std::enable_if <
I<sizeof...(Tp), void>::type
read_tuple_traits(istream& str, const std::tuple<Tp...>& t)
{
std::tuple_element_t<I, Tr>::decode(str, std::get<I>(t));
read_tuple_traits<Tr, I + 1, Tp...>(str, t);
}
template<typename T>
inline
typename std::enable_if<is_serializable<T>::value && has_traits<T>::value,
istream&>::type
operator>>(istream& str, T& obj)
{
read_tuple_traits<typename T::_tls_traits>(str, obj._tls_fields_r());
return str;
}
} // namespace mlspp::tls

146
DPP/mlspp/lib/bytes/src/bytes.cpp Executable file
View File

@@ -0,0 +1,146 @@
#include <bytes/bytes.h>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <stdexcept>
namespace mlspp::bytes_ns {
bool
bytes::operator==(const bytes& other) const
{
return *this == other._data;
}
bool
bytes::operator!=(const bytes& other) const
{
return !(*this == other._data);
}
bool
bytes::operator==(const std::vector<uint8_t>& other) const
{
const size_t size = other.size();
if (_data.size() != size) {
return false;
}
unsigned char diff = 0;
for (size_t i = 0; i < size; ++i) {
// Not sure why the linter thinks `diff` is signed
// NOLINTNEXTLINE(hicpp-signed-bitwise)
diff |= (_data.at(i) ^ other.at(i));
}
return (diff == 0);
}
bool
bytes::operator!=(const std::vector<uint8_t>& other) const
{
return !(*this == other);
}
bytes&
bytes::operator+=(const bytes& other)
{
// Not sure what the default argument is here
// NOLINTNEXTLINE(fuchsia-default-arguments)
_data.insert(end(), other.begin(), other.end());
return *this;
}
bytes
bytes::operator+(const bytes& rhs) const
{
bytes out = *this;
out += rhs;
return out;
}
bool
bytes::operator<(const bytes& rhs) const
{
return _data < rhs._data;
}
bytes
bytes::operator^(const bytes& rhs) const
{
if (size() != rhs.size()) {
throw std::invalid_argument("XOR with unequal size");
}
bytes out = *this;
for (size_t i = 0; i < size(); ++i) {
out.at(i) ^= rhs.at(i);
}
return out;
}
std::string
to_ascii(const bytes& data)
{
return { data.begin(), data.end() };
}
bytes
from_ascii(const std::string& ascii)
{
return std::vector<uint8_t>(ascii.begin(), ascii.end());
}
std::string
to_hex(const bytes& data)
{
std::stringstream hex(std::ios_base::out);
hex.flags(std::ios::hex);
for (const auto& byte : data) {
hex << std::setw(2) << std::setfill('0') << int(byte);
}
return hex.str();
}
bytes
from_hex(const std::string& hex)
{
if (hex.length() % 2 == 1) {
throw std::invalid_argument("Odd-length hex string");
}
auto len = hex.length() / 2;
auto out = bytes(len);
for (size_t i = 0; i < len; i += 1) {
const std::string byte = hex.substr(2 * i, 2);
out.at(i) = static_cast<uint8_t>(strtol(byte.c_str(), nullptr, 16));
}
return out;
}
std::ostream&
operator<<(std::ostream& out, const bytes& data)
{
// Adjust this threshold to make output more compact
const size_t threshold = 0xffff;
if (data.size() < threshold) {
return out << to_hex(data);
}
return out << to_hex(data.slice(0, threshold)) << "...";
}
bool
operator==(const std::vector<uint8_t>& lhs, const bytes_ns::bytes& rhs)
{
return rhs == lhs;
}
bool
operator!=(const std::vector<uint8_t>& lhs, const bytes_ns::bytes& rhs)
{
return rhs != lhs;
}
} // namespace mlspp::bytes_ns

View File

@@ -0,0 +1,57 @@
set(CURRENT_LIB_NAME hpke)
###
### Library Config
###
file(GLOB_RECURSE LIB_HEADERS CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/include/*.h")
file(GLOB_RECURSE LIB_SOURCES CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp")
# -Werror=dangling-reference
add_library(${CURRENT_LIB_NAME} STATIC ${LIB_HEADERS} ${LIB_SOURCES})
add_dependencies(${CURRENT_LIB_NAME} bytes tls_syntax)
target_include_directories(${CURRENT_LIB_NAME}
PRIVATE
"${JSON_INCLUDE_INTERFACE}")
target_link_libraries(${CURRENT_LIB_NAME}
PUBLIC
bytes tls_syntax
)
target_include_directories(${CURRENT_LIB_NAME}
PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/include>
$<INSTALL_INTERFACE:include/${PROJECT_NAME}>
PRIVATE
${OPENSSL_INCLUDE_DIR}
)
# Private statically linked dependencies
target_link_libraries(${CURRENT_LIB_NAME} PRIVATE
mlspp
mls_vectors
bytes
tls_syntax
)
target_compile_options(
"${CURRENT_LIB_NAME}"
PUBLIC
"$<$<PLATFORM_ID:Windows>:/bigobj;/Zc:preprocessor>"
PRIVATE
"$<$<PLATFORM_ID:Windows>:$<$<CONFIG:Debug>:/sdl;/Od;/DEBUG;/MP;/DFD_SETSIZE=1024>>"
"$<$<PLATFORM_ID:Windows>:$<$<CONFIG:Release>:/O2;/Oi;/Oy;/GL;/Gy;/sdl;/MP;/DFD_SETSIZE=1024>>"
"$<$<PLATFORM_ID:Linux>:$<$<CONFIG:Debug>:-Wall;-Wempty-body;-Wno-psabi;-Wunknown-pragmas;-Wignored-qualifiers;-Wimplicit-fallthrough;-Wmissing-field-initializers;-Wsign-compare;-Wtype-limits;-Wuninitialized;-Wshift-negative-value;-pthread;-g;-Og;-fPIC>>"
"$<$<PLATFORM_ID:Linux>:$<$<CONFIG:Release>:-Wall;-Wempty-body;-Wno-psabi;-Wunknown-pragmas;-Wignored-qualifiers;-Wimplicit-fallthrough;-Wmissing-field-initializers;-Wsign-compare;-Wtype-limits;-Wuninitialized;-Wshift-negative-value;-pthread;-O3;-fPIC>>"
"${AVX_FLAG}"
)
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/bytes/include")
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/hpke/include")
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/mls_vectors/include")
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/tls_syntax/include")
# For nlohmann/json.hpp
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../../include")

View File

@@ -0,0 +1,20 @@
#pragma once
#include <bytes/bytes.h>
using namespace mlspp::bytes_ns;
namespace mlspp::hpke {
std::string
to_base64(const bytes& data);
std::string
to_base64url(const bytes& data);
bytes
from_base64(const std::string& enc);
bytes
from_base64url(const std::string& enc);
} // namespace mlspp::hpke

View File

@@ -0,0 +1,75 @@
#pragma once
#include <memory>
#include <optional>
#include <bytes/bytes.h>
#include <chrono>
#include <hpke/signature.h>
#include <map>
using namespace mlspp::bytes_ns;
namespace mlspp::hpke {
struct Certificate
{
private:
struct ParsedCertificate;
std::unique_ptr<ParsedCertificate> parsed_cert;
public:
struct NameType
{
static const int organization;
static const int common_name;
static const int organizational_unit;
static const int country;
static const int serial_number;
static const int state_or_province_name;
};
using ParsedName = std::map<int, std::string>;
// Certificate Expiration Status
enum struct ExpirationStatus
{
inactive, // now < notBefore
active, // notBefore < now < notAfter
expired, // notAfter < now
};
explicit Certificate(const bytes& der);
explicit Certificate(std::unique_ptr<ParsedCertificate>&& parsed_cert_in);
Certificate() = delete;
Certificate(const Certificate& other);
~Certificate();
static std::vector<Certificate> parse_pem(const bytes& pem);
bool valid_from(const Certificate& parent) const;
// Accessors for parsed certificate elements
uint64_t issuer_hash() const;
uint64_t subject_hash() const;
ParsedName issuer() const;
ParsedName subject() const;
bool is_ca() const;
ExpirationStatus expiration_status() const;
std::optional<bytes> subject_key_id() const;
std::optional<bytes> authority_key_id() const;
std::vector<std::string> email_addresses() const;
std::vector<std::string> dns_names() const;
bytes hash() const;
std::chrono::system_clock::time_point not_before() const;
std::chrono::system_clock::time_point not_after() const;
Signature::ID public_key_algorithm() const;
Signature::ID signature_algorithm() const;
const std::unique_ptr<Signature::PublicKey> public_key;
const bytes raw;
};
bool
operator==(const Certificate& lhs, const Certificate& rhs);
} // namespace mlspp::hpke

View File

@@ -0,0 +1,37 @@
#pragma once
#include <memory>
#include <bytes/bytes.h>
using namespace mlspp::bytes_ns;
namespace mlspp::hpke {
struct Digest
{
enum struct ID
{
SHA256,
SHA384,
SHA512,
};
template<ID id>
static const Digest& get();
const ID id;
bytes hash(const bytes& data) const;
bytes hmac(const bytes& key, const bytes& data) const;
const size_t hash_size;
private:
explicit Digest(ID id);
bytes hmac_for_hkdf_extract(const bytes& key, const bytes& data) const;
friend struct HKDF;
};
} // namespace mlspp::hpke

View File

@@ -0,0 +1,253 @@
#pragma once
#include <memory>
#include <optional>
#include <bytes/bytes.h>
using namespace mlspp::bytes_ns;
namespace mlspp::hpke {
struct KEM
{
enum struct ID : uint16_t
{
DHKEM_P256_SHA256 = 0x0010,
DHKEM_P384_SHA384 = 0x0011,
DHKEM_P521_SHA512 = 0x0012,
DHKEM_X25519_SHA256 = 0x0020,
#if !defined(WITH_BORINGSSL)
DHKEM_X448_SHA512 = 0x0021,
#endif
};
template<KEM::ID>
static const KEM& get();
virtual ~KEM() = default;
struct PublicKey
{
virtual ~PublicKey() = default;
};
struct PrivateKey
{
virtual ~PrivateKey() = default;
virtual std::unique_ptr<PublicKey> public_key() const = 0;
};
const ID id;
const size_t secret_size;
const size_t enc_size;
const size_t pk_size;
const size_t sk_size;
virtual std::unique_ptr<PrivateKey> generate_key_pair() const = 0;
virtual std::unique_ptr<PrivateKey> derive_key_pair(
const bytes& ikm) const = 0;
virtual bytes serialize(const PublicKey& pk) const = 0;
virtual std::unique_ptr<PublicKey> deserialize(const bytes& enc) const = 0;
virtual bytes serialize_private(const PrivateKey& sk) const;
virtual std::unique_ptr<PrivateKey> deserialize_private(
const bytes& skm) const;
// (shared_secret, enc)
virtual std::pair<bytes, bytes> encap(const PublicKey& pkR) const = 0;
virtual bytes decap(const bytes& enc, const PrivateKey& skR) const = 0;
// (shared_secret, enc)
virtual std::pair<bytes, bytes> auth_encap(const PublicKey& pkR,
const PrivateKey& skS) const;
virtual bytes auth_decap(const bytes& enc,
const PublicKey& pkS,
const PrivateKey& skR) const;
protected:
KEM(ID id_in,
size_t secret_size_in,
size_t enc_size_in,
size_t pk_size_in,
size_t sk_size_in);
};
struct KDF
{
enum struct ID : uint16_t
{
HKDF_SHA256 = 0x0001,
HKDF_SHA384 = 0x0002,
HKDF_SHA512 = 0x0003,
};
template<KDF::ID id>
static const KDF& get();
virtual ~KDF() = default;
const ID id;
const size_t hash_size;
virtual bytes extract(const bytes& salt, const bytes& ikm) const = 0;
virtual bytes expand(const bytes& prk,
const bytes& info,
size_t size) const = 0;
bytes labeled_extract(const bytes& suite_id,
const bytes& salt,
const bytes& label,
const bytes& ikm) const;
bytes labeled_expand(const bytes& suite_id,
const bytes& prk,
const bytes& label,
const bytes& info,
size_t size) const;
protected:
KDF(ID id_in, size_t hash_size_in);
};
struct AEAD
{
enum struct ID : uint16_t
{
AES_128_GCM = 0x0001,
AES_256_GCM = 0x0002,
CHACHA20_POLY1305 = 0x0003,
// Reserved identifier for pseudo-AEAD on contexts that only allow export
export_only = 0xffff,
};
template<AEAD::ID id>
static const AEAD& get();
virtual ~AEAD() = default;
const ID id;
const size_t key_size;
const size_t nonce_size;
virtual bytes seal(const bytes& key,
const bytes& nonce,
const bytes& aad,
const bytes& pt) const = 0;
virtual std::optional<bytes> open(const bytes& key,
const bytes& nonce,
const bytes& aad,
const bytes& ct) const = 0;
protected:
AEAD(ID id_in, size_t key_size_in, size_t nonce_size_in);
};
struct Context
{
bytes do_export(const bytes& exporter_context, size_t size) const;
protected:
bytes suite;
bytes key;
bytes nonce;
bytes exporter_secret;
const KDF& kdf;
const AEAD& aead;
bytes current_nonce() const;
void increment_seq();
private:
uint64_t seq;
Context(bytes suite_in,
bytes key_in,
bytes nonce_in,
bytes exporter_secret_in,
const KDF& kdf_in,
const AEAD& aead_in);
friend struct HPKE;
friend struct HPKETest;
friend bool operator==(const Context& lhs, const Context& rhs);
};
struct SenderContext : public Context
{
SenderContext(Context&& c);
bytes seal(const bytes& aad, const bytes& pt);
};
struct ReceiverContext : public Context
{
ReceiverContext(Context&& c);
std::optional<bytes> open(const bytes& aad, const bytes& ct);
};
struct HPKE
{
enum struct Mode : uint8_t
{
base = 0,
psk = 1,
auth = 2,
auth_psk = 3,
};
HPKE(KEM::ID kem_id, KDF::ID kdf_id, AEAD::ID aead_id);
using SenderInfo = std::pair<bytes, SenderContext>;
SenderInfo setup_base_s(const KEM::PublicKey& pkR, const bytes& info) const;
ReceiverContext setup_base_r(const bytes& enc,
const KEM::PrivateKey& skR,
const bytes& info) const;
SenderInfo setup_psk_s(const KEM::PublicKey& pkR,
const bytes& info,
const bytes& psk,
const bytes& psk_id) const;
ReceiverContext setup_psk_r(const bytes& enc,
const KEM::PrivateKey& skR,
const bytes& info,
const bytes& psk,
const bytes& psk_id) const;
SenderInfo setup_auth_s(const KEM::PublicKey& pkR,
const bytes& info,
const KEM::PrivateKey& skS) const;
ReceiverContext setup_auth_r(const bytes& enc,
const KEM::PrivateKey& skR,
const bytes& info,
const KEM::PublicKey& pkS) const;
SenderInfo setup_auth_psk_s(const KEM::PublicKey& pkR,
const bytes& info,
const bytes& psk,
const bytes& psk_id,
const KEM::PrivateKey& skS) const;
ReceiverContext setup_auth_psk_r(const bytes& enc,
const KEM::PrivateKey& skR,
const bytes& info,
const bytes& psk,
const bytes& psk_id,
const KEM::PublicKey& pkS) const;
bytes suite;
const KEM& kem;
const KDF& kdf;
const AEAD& aead;
private:
static bool verify_psk_inputs(Mode mode,
const bytes& psk,
const bytes& psk_id);
Context key_schedule(Mode mode,
const bytes& shared_secret,
const bytes& info,
const bytes& psk,
const bytes& psk_id) const;
};
} // namespace mlspp::hpke

View File

@@ -0,0 +1,11 @@
#pragma once
#include <bytes/bytes.h>
using namespace mlspp::bytes_ns;
namespace mlspp::hpke {
bytes
random_bytes(size_t size);
} // namespace mlspp::hpke

View File

@@ -0,0 +1,89 @@
#pragma once
#include <memory>
#include <bytes/bytes.h>
using namespace mlspp::bytes_ns;
namespace mlspp::hpke {
struct Signature
{
enum struct ID
{
P256_SHA256,
P384_SHA384,
P521_SHA512,
Ed25519,
#if !defined(WITH_BORINGSSL)
Ed448,
#endif
RSA_SHA256,
RSA_SHA384,
RSA_SHA512,
};
template<Signature::ID id>
static const Signature& get();
virtual ~Signature() = default;
struct PublicKey
{
virtual ~PublicKey() = default;
};
struct PrivateKey
{
virtual ~PrivateKey() = default;
virtual std::unique_ptr<PublicKey> public_key() const = 0;
};
const ID id;
virtual std::unique_ptr<PrivateKey> generate_key_pair() const = 0;
virtual std::unique_ptr<PrivateKey> derive_key_pair(
const bytes& ikm) const = 0;
virtual bytes serialize(const PublicKey& pk) const = 0;
virtual std::unique_ptr<PublicKey> deserialize(const bytes& enc) const = 0;
virtual bytes serialize_private(const PrivateKey& sk) const = 0;
virtual std::unique_ptr<PrivateKey> deserialize_private(
const bytes& skm) const = 0;
struct PrivateJWK
{
const Signature& sig;
std::optional<std::string> key_id;
std::unique_ptr<PrivateKey> key;
};
static PrivateJWK parse_jwk_private(const std::string& jwk_json);
struct PublicJWK
{
const Signature& sig;
std::optional<std::string> key_id;
std::unique_ptr<PublicKey> key;
};
static PublicJWK parse_jwk(const std::string& jwk_json);
virtual std::unique_ptr<PrivateKey> import_jwk_private(
const std::string& jwk_json) const = 0;
virtual std::unique_ptr<PublicKey> import_jwk(
const std::string& jwk_json) const = 0;
virtual std::string export_jwk_private(const PrivateKey& env) const = 0;
virtual std::string export_jwk(const PublicKey& env) const = 0;
virtual bytes sign(const bytes& data, const PrivateKey& sk) const = 0;
virtual bool verify(const bytes& data,
const bytes& sig,
const PublicKey& pk) const = 0;
static std::unique_ptr<PrivateKey> generate_rsa(size_t bits);
protected:
Signature(ID id_in);
};
} // namespace mlspp::hpke

View File

@@ -0,0 +1,82 @@
#pragma once
#include <memory>
#include <optional>
#include <bytes/bytes.h>
#include <chrono>
#include <hpke/signature.h>
#include <map>
using namespace mlspp::bytes_ns;
namespace mlspp::hpke {
struct UserInfoClaimsAddress
{
std::optional<std::string> formatted;
std::optional<std::string> street_address;
std::optional<std::string> locality;
std::optional<std::string> region;
std::optional<std::string> postal_code;
std::optional<std::string> country;
};
struct UserInfoClaims
{
std::optional<std::string> sub;
std::optional<std::string> name;
std::optional<std::string> given_name;
std::optional<std::string> family_name;
std::optional<std::string> middle_name;
std::optional<std::string> nickname;
std::optional<std::string> preferred_username;
std::optional<std::string> profile;
std::optional<std::string> picture;
std::optional<std::string> website;
std::optional<std::string> email;
std::optional<bool> email_verified;
std::optional<std::string> gender;
std::optional<std::string> birthdate;
std::optional<std::string> zoneinfo;
std::optional<std::string> locale;
std::optional<std::string> phone_number;
std::optional<bool> phone_number_verified;
std::optional<UserInfoClaimsAddress> address;
std::optional<uint64_t> updated_at;
static UserInfoClaims from_json(const std::string& cred_subject);
};
struct UserInfoVC
{
private:
struct ParsedCredential;
std::shared_ptr<ParsedCredential> parsed_cred;
public:
explicit UserInfoVC(std::string jwt);
UserInfoVC() = default;
UserInfoVC(const UserInfoVC& other) = default;
~UserInfoVC() = default;
UserInfoVC& operator=(const UserInfoVC& other) = default;
UserInfoVC& operator=(UserInfoVC&& other) = default;
const Signature& signature_algorithm() const;
std::string issuer() const;
std::optional<std::string> key_id() const;
std::chrono::system_clock::time_point not_before() const;
std::chrono::system_clock::time_point not_after() const;
const std::string& raw_credential() const;
const UserInfoClaims& subject() const;
const Signature::PublicJWK& public_key() const;
bool valid_from(const Signature::PublicKey& issuer_key) const;
std::string raw;
};
bool
operator==(const UserInfoVC& lhs, const UserInfoVC& rhs);
} // namespace mlspp::hpke

View File

@@ -0,0 +1,321 @@
#include "aead_cipher.h"
#include "openssl_common.h"
#include <openssl/evp.h>
#if WITH_BORINGSSL
#include <openssl/aead.h>
#endif
namespace mlspp::hpke {
///
/// ExportOnlyCipher
///
bytes
ExportOnlyCipher::seal(const bytes& /* key */,
const bytes& /* nonce */,
const bytes& /* aad */,
const bytes& /* pt */) const
{
throw std::runtime_error("seal() on export-only context");
}
std::optional<bytes>
ExportOnlyCipher::open(const bytes& /* key */,
const bytes& /* nonce */,
const bytes& /* aad */,
const bytes& /* ct */) const
{
throw std::runtime_error("open() on export-only context");
}
ExportOnlyCipher::ExportOnlyCipher()
: AEAD(AEAD::ID::export_only, 0, 0)
{
}
///
/// AEADCipher
///
AEADCipher
make_aead(AEAD::ID cipher_in)
{
return { cipher_in };
}
template<>
const AEADCipher&
AEADCipher::get<AEAD::ID::AES_128_GCM>()
{
static const auto instance = make_aead(AEAD::ID::AES_128_GCM);
return instance;
}
template<>
const AEADCipher&
AEADCipher::get<AEAD::ID::AES_256_GCM>()
{
static const auto instance = make_aead(AEAD::ID::AES_256_GCM);
return instance;
}
template<>
const AEADCipher&
AEADCipher::get<AEAD::ID::CHACHA20_POLY1305>()
{
static const auto instance = make_aead(AEAD::ID::CHACHA20_POLY1305);
return instance;
}
static size_t
cipher_key_size(AEAD::ID cipher)
{
switch (cipher) {
case AEAD::ID::AES_128_GCM:
return 16;
case AEAD::ID::AES_256_GCM:
case AEAD::ID::CHACHA20_POLY1305:
return 32;
default:
throw std::runtime_error("Unsupported algorithm");
}
}
static size_t
cipher_nonce_size(AEAD::ID cipher)
{
switch (cipher) {
case AEAD::ID::AES_128_GCM:
case AEAD::ID::AES_256_GCM:
case AEAD::ID::CHACHA20_POLY1305:
return 12;
default:
throw std::runtime_error("Unsupported algorithm");
}
}
static size_t
cipher_tag_size(AEAD::ID cipher)
{
switch (cipher) {
case AEAD::ID::AES_128_GCM:
case AEAD::ID::AES_256_GCM:
case AEAD::ID::CHACHA20_POLY1305:
return 16;
default:
throw std::runtime_error("Unsupported algorithm");
}
}
#if WITH_BORINGSSL
static const EVP_AEAD*
boringssl_cipher(AEAD::ID cipher)
{
switch (cipher) {
case AEAD::ID::AES_128_GCM:
return EVP_aead_aes_128_gcm();
case AEAD::ID::AES_256_GCM:
return EVP_aead_aes_256_gcm();
case AEAD::ID::CHACHA20_POLY1305:
return EVP_aead_chacha20_poly1305();
default:
throw std::runtime_error("Unsupported algorithm");
}
}
#else
static const EVP_CIPHER*
openssl_cipher(AEAD::ID cipher)
{
switch (cipher) {
case AEAD::ID::AES_128_GCM:
return EVP_aes_128_gcm();
case AEAD::ID::AES_256_GCM:
return EVP_aes_256_gcm();
case AEAD::ID::CHACHA20_POLY1305:
return EVP_chacha20_poly1305();
default:
throw std::runtime_error("Unsupported algorithm");
}
}
#endif // WITH_BORINGSSL
AEADCipher::AEADCipher(AEAD::ID id_in)
: AEAD(id_in, cipher_key_size(id_in), cipher_nonce_size(id_in))
, tag_size(cipher_tag_size(id))
{
}
bytes
AEADCipher::seal(const bytes& key,
const bytes& nonce,
const bytes& aad,
const bytes& pt) const
{
#if WITH_BORINGSSL
auto ctx = make_typed_unique(
EVP_AEAD_CTX_new(boringssl_cipher(id), key.data(), key.size(), tag_size));
if (ctx == nullptr) {
throw openssl_error();
}
auto ct = bytes(pt.size() + tag_size);
auto out_len = ct.size();
if (1 != EVP_AEAD_CTX_seal(ctx.get(),
ct.data(),
&out_len,
ct.size(),
nonce.data(),
nonce.size(),
pt.data(),
pt.size(),
aad.data(),
aad.size())) {
throw openssl_error();
}
return ct;
#else
auto ctx = make_typed_unique(EVP_CIPHER_CTX_new());
if (ctx == nullptr) {
throw openssl_error();
}
const auto* cipher = openssl_cipher(id);
if (1 != EVP_EncryptInit(ctx.get(), cipher, key.data(), nonce.data())) {
throw openssl_error();
}
int outlen = 0;
if (!aad.empty()) {
if (1 != EVP_EncryptUpdate(ctx.get(),
nullptr,
&outlen,
aad.data(),
static_cast<int>(aad.size()))) {
throw openssl_error();
}
}
bytes ct(pt.size());
if (1 != EVP_EncryptUpdate(ctx.get(),
ct.data(),
&outlen,
pt.data(),
static_cast<int>(pt.size()))) {
throw openssl_error();
}
// Providing nullptr as an argument is safe here because this
// function never writes with GCM; it only computes the tag
if (1 != EVP_EncryptFinal(ctx.get(), nullptr, &outlen)) {
throw openssl_error();
}
bytes tag(tag_size);
if (1 != EVP_CIPHER_CTX_ctrl(ctx.get(),
EVP_CTRL_GCM_GET_TAG,
static_cast<int>(tag_size),
tag.data())) {
throw openssl_error();
}
ct += tag;
return ct;
#endif // WITH_BORINGSSL
}
std::optional<bytes>
AEADCipher::open(const bytes& key,
const bytes& nonce,
const bytes& aad,
const bytes& ct) const
{
if (ct.size() < tag_size) {
throw std::runtime_error("AEAD ciphertext smaller than tag size");
}
#if WITH_BORINGSSL
auto ctx = make_typed_unique(EVP_AEAD_CTX_new(
boringssl_cipher(id), key.data(), key.size(), cipher_tag_size(id)));
if (ctx == nullptr) {
throw openssl_error();
}
auto pt = bytes(ct.size() - tag_size);
auto out_len = pt.size();
if (1 != EVP_AEAD_CTX_open(ctx.get(),
pt.data(),
&out_len,
pt.size(),
nonce.data(),
nonce.size(),
ct.data(),
ct.size(),
aad.data(),
aad.size())) {
throw openssl_error();
}
return pt;
#else
auto ctx = make_typed_unique(EVP_CIPHER_CTX_new());
if (ctx == nullptr) {
throw openssl_error();
}
const auto* cipher = openssl_cipher(id);
if (1 != EVP_DecryptInit(ctx.get(), cipher, key.data(), nonce.data())) {
throw openssl_error();
}
auto inner_ct_size = ct.size() - tag_size;
auto tag = ct.slice(inner_ct_size, ct.size());
if (1 != EVP_CIPHER_CTX_ctrl(ctx.get(),
EVP_CTRL_GCM_SET_TAG,
static_cast<int>(tag_size),
tag.data())) {
throw openssl_error();
}
int out_size = 0;
if (!aad.empty()) {
if (1 != EVP_DecryptUpdate(ctx.get(),
nullptr,
&out_size,
aad.data(),
static_cast<int>(aad.size()))) {
throw openssl_error();
}
}
bytes pt(inner_ct_size);
if (1 != EVP_DecryptUpdate(ctx.get(),
pt.data(),
&out_size,
ct.data(),
static_cast<int>(inner_ct_size))) {
throw openssl_error();
}
// Providing nullptr as an argument is safe here because this
// function never writes with GCM; it only verifies the tag
if (1 != EVP_DecryptFinal(ctx.get(), nullptr, &out_size)) {
throw std::runtime_error("AEAD authentication failure");
}
return pt;
#endif // WITH_BORINGSSL
}
} // namespace mlspp::hpke

View File

@@ -0,0 +1,45 @@
#pragma once
#include <hpke/hpke.h>
namespace mlspp::hpke {
struct ExportOnlyCipher : public AEAD
{
ExportOnlyCipher();
~ExportOnlyCipher() override = default;
bytes seal(const bytes& key,
const bytes& nonce,
const bytes& aad,
const bytes& pt) const override;
std::optional<bytes> open(const bytes& key,
const bytes& nonce,
const bytes& aad,
const bytes& ct) const override;
};
struct AEADCipher : public AEAD
{
template<AEAD::ID id>
static const AEADCipher& get();
~AEADCipher() override = default;
bytes seal(const bytes& key,
const bytes& nonce,
const bytes& aad,
const bytes& pt) const override;
std::optional<bytes> open(const bytes& key,
const bytes& nonce,
const bytes& aad,
const bytes& ct) const override;
private:
const size_t tag_size;
AEADCipher(AEAD::ID id_in);
friend AEADCipher make_aead(AEAD::ID cipher_in);
};
} // namespace mlspp::hpke

105
DPP/mlspp/lib/hpke/src/base64.cpp Executable file
View File

@@ -0,0 +1,105 @@
#include <hpke/base64.h>
#include "openssl_common.h"
#include <openssl/bio.h>
#include <openssl/err.h>
#include <openssl/evp.h>
namespace mlspp::hpke {
std::string
to_base64(const bytes& data)
{
if (data.empty()) {
return "";
}
#if WITH_BORINGSSL
const auto data_size = data.size();
#else
const auto data_size = static_cast<int>(data.size());
#endif
// base64 encoding produces 4 characters for every 3 input bytes (rounded up)
const auto out_size = (data_size + 2) / 3 * 4;
auto out = bytes(out_size + 1); // NUL terminator
const auto result = EVP_EncodeBlock(out.data(), data.data(), data_size);
if (result != out_size) {
throw openssl_error();
}
out.resize(out.size() - 1); // strip NUL terminator
return to_ascii(out);
}
std::string
to_base64url(const bytes& data)
{
if (data.empty()) {
return "";
}
auto encoded = to_base64(data);
auto pad_start = encoded.find_first_of('=');
if (pad_start != std::string::npos) {
encoded = encoded.substr(0, pad_start);
}
std::replace(encoded.begin(), encoded.end(), '+', '-');
std::replace(encoded.begin(), encoded.end(), '/', '_');
return encoded;
}
bytes
from_base64(const std::string& enc)
{
if (enc.length() == 0) {
return {};
}
if (enc.length() % 4 != 0) {
throw std::runtime_error("Base64 length is not divisible by 4");
}
const auto in = from_ascii(enc);
const auto in_size = static_cast<int>(in.size());
const auto out_size = in_size / 4 * 3;
auto out = bytes(out_size);
const auto result = EVP_DecodeBlock(out.data(), in.data(), in_size);
if (result != out_size) {
throw openssl_error();
}
if (enc.substr(enc.length() - 2, enc.length()) == "==") {
out.resize(out.size() - 2);
} else if (enc.substr(enc.length() - 1, enc.length()) == "=") {
out.resize(out.size() - 1);
}
return out;
}
bytes
from_base64url(const std::string& enc)
{
if (enc.empty()) {
return {};
}
auto enc_copy = enc;
std::replace(enc_copy.begin(), enc_copy.end(), '-', '+');
std::replace(enc_copy.begin(), enc_copy.end(), '_', '/');
while (enc_copy.length() % 4 != 0) {
enc_copy += "=";
}
return from_base64(enc_copy);
}
} // namespace mlspp::hpke

View File

@@ -0,0 +1,539 @@
#include "group.h"
#include "openssl_common.h"
#include "rsa.h"
#include <hpke/certificate.h>
#include <hpke/signature.h>
#include <memory>
#include <openssl/err.h>
#include <openssl/pem.h>
#include <openssl/x509.h>
#include <openssl/x509v3.h>
#include <tls/compat.h>
namespace mlspp::hpke {
///
/// Utility functions
///
static std::optional<bytes>
asn1_octet_string_to_bytes(const ASN1_OCTET_STRING* octets)
{
if (octets == nullptr) {
return std::nullopt;
}
const auto* ptr = ASN1_STRING_get0_data(octets);
const auto len = ASN1_STRING_length(octets);
// NOLINTNEXTLINE (cppcoreguidelines-pro-bounds-pointer-arithmetic)
return std::vector<uint8_t>(ptr, ptr + len);
}
static std::string
asn1_string_to_std_string(const ASN1_STRING* asn1_string)
{
const auto* data = ASN1_STRING_get0_data(asn1_string);
const auto data_size = static_cast<size_t>(ASN1_STRING_length(asn1_string));
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
auto str = std::string(reinterpret_cast<const char*>(data));
if (str.size() != data_size) {
throw std::runtime_error("Malformed ASN.1 string");
}
return str;
}
static std::chrono::system_clock::time_point
asn1_time_to_chrono(const ASN1_TIME* asn1_time)
{
auto epoch_chrono = std::chrono::system_clock::time_point();
auto epoch_time_t = std::chrono::system_clock::to_time_t(epoch_chrono);
auto epoch_asn1 = make_typed_unique(ASN1_TIME_set(nullptr, epoch_time_t));
if (!epoch_asn1) {
throw openssl_error();
}
auto secs = int(0);
auto days = int(0);
if (ASN1_TIME_diff(&days, &secs, epoch_asn1.get(), asn1_time) != 1) {
throw openssl_error();
}
auto delta = std::chrono::seconds(secs) + std::chrono::hours(24 * days);
return std::chrono::system_clock::time_point(delta);
}
///
/// ParsedCertificate
///
const int Certificate::NameType::organization = NID_organizationName;
const int Certificate::NameType::common_name = NID_commonName;
const int Certificate::NameType::organizational_unit =
NID_organizationalUnitName;
const int Certificate::NameType::country = NID_countryName;
const int Certificate::NameType::serial_number = NID_serialNumber;
const int Certificate::NameType::state_or_province_name =
NID_stateOrProvinceName;
struct RFC822Name
{
std::string value;
};
struct DNSName
{
std::string value;
};
using GeneralName = tls::var::variant<RFC822Name, DNSName>;
struct Certificate::ParsedCertificate
{
static std::unique_ptr<ParsedCertificate> parse(const bytes& der)
{
const auto* buf = der.data();
auto cert =
make_typed_unique(d2i_X509(nullptr, &buf, static_cast<int>(der.size())));
if (cert == nullptr) {
throw openssl_error();
}
return std::make_unique<ParsedCertificate>(cert.release());
}
static bytes compute_digest(const X509* cert)
{
const auto* md = EVP_sha256();
auto digest = bytes(EVP_MD_size(md));
unsigned int out_size = 0;
if (1 != X509_digest(cert, md, digest.data(), &out_size)) {
throw openssl_error();
}
return digest;
}
// Note: This method does not implement total general name parsing.
// Duplicate entries are not supported; if they are present, the last one
// presented by OpenSSL is chosen.
static ParsedName parse_names(const X509_NAME* x509_name)
{
if (x509_name == nullptr) {
throw openssl_error();
}
ParsedName parsed_name;
for (int i = X509_NAME_entry_count(x509_name) - 1; i >= 0; i--) {
auto* entry = X509_NAME_get_entry(x509_name, i);
if (entry == nullptr) {
continue;
}
auto* oid = X509_NAME_ENTRY_get_object(entry);
auto* asn_str = X509_NAME_ENTRY_get_data(entry);
if (oid == nullptr || asn_str == nullptr) {
continue;
}
const int nid = OBJ_obj2nid(oid);
const std::string parsed_value = asn1_string_to_std_string(asn_str);
parsed_name[nid] = parsed_value;
}
return parsed_name;
}
// Parse Subject Key Identifier Extension
static std::optional<bytes> parse_skid(X509* cert)
{
return asn1_octet_string_to_bytes(X509_get0_subject_key_id(cert));
}
// Parse Authority Key Identifier
static std::optional<bytes> parse_akid(X509* cert)
{
return asn1_octet_string_to_bytes(X509_get0_authority_key_id(cert));
}
static std::vector<GeneralName> parse_san(X509* cert)
{
std::vector<GeneralName> names;
#ifdef WITH_BORINGSSL
using san_names_nb_t = size_t;
#else
using san_names_nb_t = int;
#endif
san_names_nb_t san_names_nb = 0;
auto* ext_ptr =
X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr);
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
auto* san_ptr = reinterpret_cast<STACK_OF(GENERAL_NAME)*>(ext_ptr);
const auto san_names = make_typed_unique(san_ptr);
san_names_nb = sk_GENERAL_NAME_num(san_names.get());
// Check each name within the extension
for (san_names_nb_t i = 0; i < san_names_nb; i++) {
auto* current_name = sk_GENERAL_NAME_value(san_names.get(), i);
if (current_name->type == GEN_DNS) {
const auto dns_name =
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-union-access)
asn1_string_to_std_string(current_name->d.dNSName);
names.emplace_back(DNSName{ dns_name });
} else if (current_name->type == GEN_EMAIL) {
const auto email =
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-union-access
asn1_string_to_std_string(current_name->d.rfc822Name);
names.emplace_back(RFC822Name{ email });
}
}
return names;
}
explicit ParsedCertificate(X509* x509_in)
: x509(x509_in, typed_delete<X509>)
, pub_key_id(public_key_algorithm(x509.get()))
, sig_algo(signature_algorithm(x509.get()))
, issuer_hash(X509_issuer_name_hash(x509.get()))
, subject_hash(X509_subject_name_hash(x509.get()))
, issuer(parse_names(X509_get_issuer_name(x509.get())))
, subject(parse_names(X509_get_subject_name(x509.get())))
, subject_key_id(parse_skid(x509.get()))
, authority_key_id(parse_akid(x509.get()))
, sub_alt_names(parse_san(x509.get()))
, is_ca(X509_check_ca(x509.get()) != 0)
, hash(compute_digest(x509.get()))
, not_before(asn1_time_to_chrono(X509_get0_notBefore(x509.get())))
, not_after(asn1_time_to_chrono(X509_get0_notAfter(x509.get())))
{
}
ParsedCertificate(const ParsedCertificate& other)
: x509(nullptr, typed_delete<X509>)
, pub_key_id(public_key_algorithm(other.x509.get()))
, sig_algo(signature_algorithm(other.x509.get()))
, issuer_hash(other.issuer_hash)
, subject_hash(other.subject_hash)
, issuer(other.issuer)
, subject(other.subject)
, subject_key_id(other.subject_key_id)
, authority_key_id(other.authority_key_id)
, sub_alt_names(other.sub_alt_names)
, is_ca(other.is_ca)
, hash(other.hash)
, not_before(other.not_before)
, not_after(other.not_after)
{
if (1 != X509_up_ref(other.x509.get())) {
throw openssl_error();
}
x509.reset(other.x509.get());
}
static Signature::ID public_key_algorithm(X509* x509)
{
#if WITH_BORINGSSL
const auto pub = make_typed_unique(X509_get_pubkey(x509));
const auto* pub_ptr = pub.get();
#else
const auto* pub_ptr = X509_get0_pubkey(x509);
#endif
switch (EVP_PKEY_base_id(pub_ptr)) {
case EVP_PKEY_ED25519:
return Signature::ID::Ed25519;
#if !defined(WITH_BORINGSSL)
case EVP_PKEY_ED448:
return Signature::ID::Ed448;
#endif
case EVP_PKEY_EC: {
auto key_size = EVP_PKEY_bits(pub_ptr);
switch (key_size) {
case 256:
return Signature::ID::P256_SHA256;
case 384:
return Signature::ID::P384_SHA384;
case 521:
return Signature::ID::P521_SHA512;
default:
throw std::runtime_error("Unknown curve");
}
}
case EVP_PKEY_RSA:
// RSA public keys are not specific to an algorithm
return Signature::ID::RSA_SHA256;
default:
break;
}
throw std::runtime_error("Unsupported public key algorithm");
}
static Signature::ID signature_algorithm(X509* cert)
{
auto nid = X509_get_signature_nid(cert);
switch (nid) {
case EVP_PKEY_ED25519:
return Signature::ID::Ed25519;
#if !defined(WITH_BORINGSSL)
case EVP_PKEY_ED448:
return Signature::ID::Ed448;
#endif
case NID_ecdsa_with_SHA256:
return Signature::ID::P256_SHA256;
case NID_ecdsa_with_SHA384:
return Signature::ID::P384_SHA384;
case NID_ecdsa_with_SHA512:
return Signature::ID::P521_SHA512;
case NID_sha1WithRSAEncryption:
// We fall through to SHA256 for SHA1 because we do not implement SHA-1.
case NID_sha256WithRSAEncryption:
return Signature::ID::RSA_SHA256;
case NID_sha384WithRSAEncryption:
return Signature::ID::RSA_SHA384;
case NID_sha512WithRSAEncryption:
return Signature::ID::RSA_SHA512;
default:
break;
}
throw std::runtime_error("Unsupported signature algorithm");
}
typed_unique_ptr<EVP_PKEY> public_key() const
{
return make_typed_unique<EVP_PKEY>(X509_get_pubkey(x509.get()));
}
Certificate::ExpirationStatus expiration_status() const
{
auto now = std::chrono::system_clock::now();
if (now < not_before) {
return Certificate::ExpirationStatus::inactive;
}
if (now > not_after) {
return Certificate::ExpirationStatus::expired;
}
return Certificate::ExpirationStatus::active;
}
bytes raw() const
{
auto out = bytes(i2d_X509(x509.get(), nullptr));
auto* ptr = out.data();
i2d_X509(x509.get(), &ptr);
return out;
}
typed_unique_ptr<X509> x509;
const Signature::ID pub_key_id;
const Signature::ID sig_algo;
const uint64_t issuer_hash;
const uint64_t subject_hash;
const ParsedName issuer;
const ParsedName subject;
const std::optional<bytes> subject_key_id;
const std::optional<bytes> authority_key_id;
const std::vector<GeneralName> sub_alt_names;
const bool is_ca;
const bytes hash;
const std::chrono::system_clock::time_point not_before;
const std::chrono::system_clock::time_point not_after;
};
///
/// Certificate
///
static std::unique_ptr<Signature::PublicKey>
signature_key(EVP_PKEY* pkey)
{
switch (EVP_PKEY_base_id(pkey)) {
case EVP_PKEY_RSA:
return std::make_unique<RSASignature::PublicKey>(pkey);
case EVP_PKEY_ED448:
case EVP_PKEY_ED25519:
case EVP_PKEY_EC:
return std::make_unique<EVPGroup::PublicKey>(pkey);
default:
throw std::runtime_error("Unsupported algorithm");
}
}
Certificate::Certificate(std::unique_ptr<ParsedCertificate>&& parsed_cert_in)
: parsed_cert(std::move(parsed_cert_in))
, public_key(signature_key(parsed_cert->public_key().release()))
, raw(parsed_cert->raw())
{
}
Certificate::Certificate(const bytes& der)
: parsed_cert(ParsedCertificate::parse(der))
, public_key(signature_key(parsed_cert->public_key().release()))
, raw(der)
{
}
Certificate::Certificate(const Certificate& other)
: parsed_cert(std::make_unique<ParsedCertificate>(*other.parsed_cert))
, public_key(signature_key(parsed_cert->public_key().release()))
, raw(other.raw)
{
}
Certificate::~Certificate() = default;
std::vector<Certificate>
Certificate::parse_pem(const bytes& pem)
{
auto size_int = static_cast<int>(pem.size());
auto bio = make_typed_unique<BIO>(BIO_new_mem_buf(pem.data(), size_int));
if (!bio) {
throw openssl_error();
}
auto certs = std::vector<Certificate>();
while (true) {
auto x509 = make_typed_unique<X509>(
PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
if (!x509) {
// NOLINTNEXTLINE(hicpp-signed-bitwise)
auto err = ERR_GET_REASON(ERR_peek_last_error());
if (err == PEM_R_NO_START_LINE) {
// No more objects to read
break;
}
throw openssl_error();
}
auto parsed = std::make_unique<ParsedCertificate>(x509.release());
certs.emplace_back(std::move(parsed));
}
return certs;
}
bool
Certificate::valid_from(const Certificate& parent) const
{
auto pub = parent.parsed_cert->public_key();
return (1 == X509_verify(parsed_cert->x509.get(), pub.get()));
}
uint64_t
Certificate::issuer_hash() const
{
return parsed_cert->issuer_hash;
}
uint64_t
Certificate::subject_hash() const
{
return parsed_cert->subject_hash;
}
Certificate::ParsedName
Certificate::subject() const
{
return parsed_cert->subject;
}
Certificate::ParsedName
Certificate::issuer() const
{
return parsed_cert->issuer;
}
bool
Certificate::is_ca() const
{
return parsed_cert->is_ca;
}
Certificate::ExpirationStatus
Certificate::expiration_status() const
{
return parsed_cert->expiration_status();
}
std::optional<bytes>
Certificate::subject_key_id() const
{
return parsed_cert->subject_key_id;
}
std::optional<bytes>
Certificate::authority_key_id() const
{
return parsed_cert->authority_key_id;
}
std::vector<std::string>
Certificate::email_addresses() const
{
std::vector<std::string> emails;
for (const auto& name : parsed_cert->sub_alt_names) {
if (tls::var::holds_alternative<RFC822Name>(name)) {
emails.emplace_back(tls::var::get<RFC822Name>(name).value);
}
}
return emails;
}
std::vector<std::string>
Certificate::dns_names() const
{
std::vector<std::string> domains;
for (const auto& name : parsed_cert->sub_alt_names) {
if (tls::var::holds_alternative<DNSName>(name)) {
domains.emplace_back(tls::var::get<DNSName>(name).value);
}
}
return domains;
}
bytes
Certificate::hash() const
{
return parsed_cert->hash;
}
std::chrono::system_clock::time_point
Certificate::not_before() const
{
return parsed_cert->not_before;
}
std::chrono::system_clock::time_point
Certificate::not_after() const
{
return parsed_cert->not_after;
}
Signature::ID
Certificate::public_key_algorithm() const
{
return parsed_cert->pub_key_id;
}
Signature::ID
Certificate::signature_algorithm() const
{
return parsed_cert->sig_algo;
}
bool
operator==(const Certificate& lhs, const Certificate& rhs)
{
return lhs.raw == rhs.raw;
}
} // namespace mlspp::hpke

View File

@@ -0,0 +1,20 @@
#include "common.h"
namespace mlspp::hpke {
bytes
i2osp(uint64_t val, size_t size)
{
auto out = bytes(size, 0);
auto max = size;
if (size > 8) {
max = 8;
}
for (size_t i = 0; i < max; i++) {
out.at(size - i - 1) = static_cast<uint8_t>(val >> (8 * i));
}
return out;
}
} // namespace mlspp::hpke

10
DPP/mlspp/lib/hpke/src/common.h Executable file
View File

@@ -0,0 +1,10 @@
#pragma once
#include <hpke/hpke.h>
namespace mlspp::hpke {
bytes
i2osp(uint64_t val, size_t size);
} // namespace mlspp::hpke

216
DPP/mlspp/lib/hpke/src/dhkem.cpp Executable file
View File

@@ -0,0 +1,216 @@
#include "dhkem.h"
#include "common.h"
namespace mlspp::hpke {
DHKEM::PrivateKey::PrivateKey(Group::PrivateKey* group_priv_in)
: group_priv(group_priv_in)
{
}
std::unique_ptr<KEM::PublicKey>
DHKEM::PrivateKey::public_key() const
{
return group_priv->public_key();
}
DHKEM
make_dhkem(KEM::ID kem_id_in, const Group& group_in, const KDF& kdf_in)
{
return { kem_id_in, group_in, kdf_in };
}
template<>
const DHKEM&
DHKEM::get<KEM::ID::DHKEM_P256_SHA256>()
{
static const auto instance = make_dhkem(KEM::ID::DHKEM_P256_SHA256,
Group::get<Group::ID::P256>(),
KDF::get<KDF::ID::HKDF_SHA256>());
return instance;
}
template<>
const DHKEM&
DHKEM::get<KEM::ID::DHKEM_P384_SHA384>()
{
static const auto instance = make_dhkem(KEM::ID::DHKEM_P384_SHA384,
Group::get<Group::ID::P384>(),
KDF::get<KDF::ID::HKDF_SHA384>());
return instance;
}
template<>
const DHKEM&
DHKEM::get<KEM::ID::DHKEM_P521_SHA512>()
{
static const auto instance = make_dhkem(KEM::ID::DHKEM_P521_SHA512,
Group::get<Group::ID::P521>(),
KDF::get<KDF::ID::HKDF_SHA512>());
return instance;
}
template<>
const DHKEM&
DHKEM::get<KEM::ID::DHKEM_X25519_SHA256>()
{
static const auto instance = make_dhkem(KEM::ID::DHKEM_X25519_SHA256,
Group::get<Group::ID::X25519>(),
KDF::get<KDF::ID::HKDF_SHA256>());
return instance;
}
#if !defined(WITH_BORINGSSL)
template<>
const DHKEM&
DHKEM::get<KEM::ID::DHKEM_X448_SHA512>()
{
static const auto instance = make_dhkem(KEM::ID::DHKEM_X448_SHA512,
Group::get<Group::ID::X448>(),
KDF::get<KDF::ID::HKDF_SHA512>());
return instance;
}
#endif
DHKEM::DHKEM(KEM::ID kem_id_in, const Group& group_in, const KDF& kdf_in)
: KEM(kem_id_in,
kdf_in.hash_size,
group_in.pk_size,
group_in.pk_size,
group_in.sk_size)
, group(group_in)
, kdf(kdf_in)
{
static const auto label_kem = from_ascii("KEM");
suite_id = label_kem + i2osp(uint16_t(kem_id_in), 2);
}
std::unique_ptr<KEM::PrivateKey>
DHKEM::generate_key_pair() const
{
return std::make_unique<DHKEM::PrivateKey>(
group.generate_key_pair().release());
}
std::unique_ptr<KEM::PrivateKey>
DHKEM::derive_key_pair(const bytes& ikm) const
{
return std::make_unique<DHKEM::PrivateKey>(
group.derive_key_pair(suite_id, ikm).release());
}
bytes
DHKEM::serialize(const KEM::PublicKey& pk) const
{
const auto& gpk = dynamic_cast<const Group::PublicKey&>(pk);
return group.serialize(gpk);
}
std::unique_ptr<KEM::PublicKey>
DHKEM::deserialize(const bytes& enc) const
{
return group.deserialize(enc);
}
bytes
DHKEM::serialize_private(const KEM::PrivateKey& sk) const
{
const auto& gsk = dynamic_cast<const PrivateKey&>(sk);
return group.serialize_private(*gsk.group_priv);
}
std::unique_ptr<KEM::PrivateKey>
DHKEM::deserialize_private(const bytes& skm) const
{
return std::make_unique<PrivateKey>(group.deserialize_private(skm).release());
}
std::pair<bytes, bytes>
DHKEM::encap(const KEM::PublicKey& pkR) const
{
const auto& gpkR = dynamic_cast<const Group::PublicKey&>(pkR);
auto skE = group.generate_key_pair();
auto pkE = skE->public_key();
auto zz = group.dh(*skE, gpkR);
auto enc = group.serialize(*pkE);
auto pkRm = group.serialize(gpkR);
auto kem_context = enc + pkRm;
auto shared_secret = extract_and_expand(zz, kem_context);
return std::make_pair(shared_secret, enc);
}
bytes
DHKEM::decap(const bytes& enc, const KEM::PrivateKey& skR) const
{
const auto& gskR = dynamic_cast<const PrivateKey&>(skR);
auto pkR = gskR.group_priv->public_key();
auto pkE = group.deserialize(enc);
auto zz = group.dh(*gskR.group_priv, *pkE);
auto pkRm = group.serialize(*pkR);
auto kem_context = enc + pkRm;
return extract_and_expand(zz, kem_context);
}
std::pair<bytes, bytes>
DHKEM::auth_encap(const KEM::PublicKey& pkR, const KEM::PrivateKey& skS) const
{
const auto& gpkR = dynamic_cast<const Group::PublicKey&>(pkR);
const auto& gskS = dynamic_cast<const PrivateKey&>(skS);
auto skE = group.generate_key_pair();
auto pkE = skE->public_key();
auto pkS = gskS.group_priv->public_key();
auto zzER = group.dh(*skE, gpkR);
auto zzSR = group.dh(*gskS.group_priv, gpkR);
auto zz = zzER + zzSR;
auto enc = group.serialize(*pkE);
auto pkRm = group.serialize(gpkR);
auto pkSm = group.serialize(*pkS);
auto kem_context = enc + pkRm + pkSm;
auto shared_secret = extract_and_expand(zz, kem_context);
return std::make_pair(shared_secret, enc);
}
bytes
DHKEM::auth_decap(const bytes& enc,
const KEM::PublicKey& pkS,
const KEM::PrivateKey& skR) const
{
const auto& gpkS = dynamic_cast<const Group::PublicKey&>(pkS);
const auto& gskR = dynamic_cast<const PrivateKey&>(skR);
auto pkE = group.deserialize(enc);
auto pkR = gskR.group_priv->public_key();
auto zzER = group.dh(*gskR.group_priv, *pkE);
auto zzSR = group.dh(*gskR.group_priv, gpkS);
auto zz = zzER + zzSR;
auto pkRm = group.serialize(*pkR);
auto pkSm = group.serialize(gpkS);
auto kem_context = enc + pkRm + pkSm;
return extract_and_expand(zz, kem_context);
}
bytes
DHKEM::extract_and_expand(const bytes& dh, const bytes& kem_context) const
{
static const auto label_eae_prk = from_ascii("eae_prk");
static const auto label_shared_secret = from_ascii("shared_secret");
auto eae_prk = kdf.labeled_extract(suite_id, {}, label_eae_prk, dh);
return kdf.labeled_expand(
suite_id, eae_prk, label_shared_secret, kem_context, secret_size);
}
} // namespace mlspp::hpke

57
DPP/mlspp/lib/hpke/src/dhkem.h Executable file
View File

@@ -0,0 +1,57 @@
#pragma once
#include <hpke/hpke.h>
#include "group.h"
namespace mlspp::hpke {
struct DHKEM : public KEM
{
struct PrivateKey : public KEM::PrivateKey
{
PrivateKey(Group::PrivateKey* group_priv_in);
std::unique_ptr<KEM::PublicKey> public_key() const override;
std::unique_ptr<Group::PrivateKey> group_priv;
};
template<KEM::ID>
static const DHKEM& get();
~DHKEM() override = default;
std::unique_ptr<KEM::PrivateKey> generate_key_pair() const override;
std::unique_ptr<KEM::PrivateKey> derive_key_pair(
const bytes& ikm) const override;
bytes serialize(const KEM::PublicKey& pk) const override;
std::unique_ptr<KEM::PublicKey> deserialize(const bytes& enc) const override;
bytes serialize_private(const KEM::PrivateKey& sk) const override;
std::unique_ptr<KEM::PrivateKey> deserialize_private(
const bytes& skm) const override;
std::pair<bytes, bytes> encap(const KEM::PublicKey& pk) const override;
bytes decap(const bytes& enc, const KEM::PrivateKey& sk) const override;
std::pair<bytes, bytes> auth_encap(const KEM::PublicKey& pkR,
const KEM::PrivateKey& skS) const override;
bytes auth_decap(const bytes& enc,
const KEM::PublicKey& pkS,
const KEM::PrivateKey& skR) const override;
private:
const Group& group;
const KDF& kdf;
bytes suite_id;
bytes extract_and_expand(const bytes& dh, const bytes& kem_context) const;
DHKEM(KEM::ID kem_id_in, const Group& group_in, const KDF& kdf_in);
friend DHKEM make_dhkem(KEM::ID kem_id_in,
const Group& group_in,
const KDF& kdf_in);
};
} // namespace mlspp::hpke

187
DPP/mlspp/lib/hpke/src/digest.cpp Executable file
View File

@@ -0,0 +1,187 @@
#include <hpke/digest.h>
#include <openssl/evp.h>
#include <openssl/hmac.h>
#if defined(WITH_OPENSSL3)
#include <openssl/core_names.h>
#endif
#include "openssl_common.h"
namespace mlspp::hpke {
static const EVP_MD*
openssl_digest_type(Digest::ID digest)
{
switch (digest) {
case Digest::ID::SHA256:
return EVP_sha256();
case Digest::ID::SHA384:
return EVP_sha384();
case Digest::ID::SHA512:
return EVP_sha512();
default:
throw std::runtime_error("Unsupported ciphersuite");
}
}
#if defined(WITH_OPENSSL3)
static std::string
openssl_digest_name(Digest::ID digest)
{
switch (digest) {
case Digest::ID::SHA256:
return OSSL_DIGEST_NAME_SHA2_256;
case Digest::ID::SHA384:
return OSSL_DIGEST_NAME_SHA2_384;
case Digest::ID::SHA512:
return OSSL_DIGEST_NAME_SHA2_512;
default:
throw std::runtime_error("Unsupported digest algorithm");
}
}
#endif
template<>
const Digest&
Digest::get<Digest::ID::SHA256>()
{
static const Digest instance(Digest::ID::SHA256);
return instance;
}
template<>
const Digest&
Digest::get<Digest::ID::SHA384>()
{
static const Digest instance(Digest::ID::SHA384);
return instance;
}
template<>
const Digest&
Digest::get<Digest::ID::SHA512>()
{
static const Digest instance(Digest::ID::SHA512);
return instance;
}
Digest::Digest(Digest::ID id_in)
: id(id_in)
, hash_size(EVP_MD_size(openssl_digest_type(id_in)))
{
}
bytes
Digest::hash(const bytes& data) const
{
auto md = bytes(hash_size);
unsigned int size = 0;
const auto* type = openssl_digest_type(id);
if (1 !=
EVP_Digest(data.data(), data.size(), md.data(), &size, type, nullptr)) {
throw openssl_error();
}
return md;
}
bytes
Digest::hmac(const bytes& key, const bytes& data) const
{
auto md = bytes(hash_size);
unsigned int size = 0;
const auto* type = openssl_digest_type(id);
if (nullptr == HMAC(type,
key.data(),
static_cast<int>(key.size()),
data.data(),
static_cast<int>(data.size()),
md.data(),
&size)) {
throw openssl_error();
}
return md;
}
bytes
Digest::hmac_for_hkdf_extract(const bytes& key, const bytes& data) const
{
#if defined(WITH_OPENSSL3)
auto digest_name = openssl_digest_name(id);
std::array<OSSL_PARAM, 2> params = {
OSSL_PARAM_construct_utf8_string(
OSSL_ALG_PARAM_DIGEST, digest_name.data(), 0),
OSSL_PARAM_construct_end()
};
const auto mac =
make_typed_unique(EVP_MAC_fetch(nullptr, OSSL_MAC_NAME_HMAC, nullptr));
const auto ctx = make_typed_unique(EVP_MAC_CTX_new(mac.get()));
#else
const auto* type = openssl_digest_type(id);
auto ctx = make_typed_unique(HMAC_CTX_new());
#endif
if (ctx == nullptr) {
throw openssl_error();
}
// Some FIPS-enabled libraries are overly conservative in their interpretation
// of NIST SP 800-131A, which requires HMAC keys to be at least 112 bits long.
// That document does not impose that requirement on HKDF, so we disable FIPS
// enforcement for purposes of HKDF.
//
// https://doi.org/10.6028/NIST.SP.800-131Ar2
auto key_size = static_cast<int>(key.size());
// OpenSSL 3 does not support the flag EVP_MD_CTX_FLAG_NON_FIPS_ALLOW anymore.
// However, OpenSSL 3 in FIPS mode doesn't seem to check the HMAC key size
// constraint.
#if !defined(WITH_OPENSSL3) && !defined(WITH_BORINGSSL)
static const auto fips_min_hmac_key_len = 14;
if (FIPS_mode() != 0 && key_size < fips_min_hmac_key_len) {
HMAC_CTX_set_flags(ctx.get(), EVP_MD_CTX_FLAG_NON_FIPS_ALLOW);
}
#endif
// Guard against sending nullptr to HMAC_Init_ex
const auto* key_data = key.data();
const auto non_null_zero_length_key = uint8_t(0);
if (key_data == nullptr) {
key_data = &non_null_zero_length_key;
}
auto md = bytes(hash_size);
#if defined(WITH_OPENSSL3)
if (1 != EVP_MAC_init(ctx.get(), key_data, key_size, params.data())) {
throw openssl_error();
}
if (1 != EVP_MAC_update(ctx.get(), data.data(), data.size())) {
throw openssl_error();
}
size_t size = 0;
if (1 != EVP_MAC_final(ctx.get(), md.data(), &size, hash_size)) {
throw openssl_error();
}
#else
if (1 != HMAC_Init_ex(ctx.get(), key_data, key_size, type, nullptr)) {
throw openssl_error();
}
if (1 != HMAC_Update(ctx.get(), data.data(), data.size())) {
throw openssl_error();
}
unsigned int size = 0;
if (1 != HMAC_Final(ctx.get(), md.data(), &size)) {
throw openssl_error();
}
#endif
return md;
}
} // namespace mlspp::hpke

1077
DPP/mlspp/lib/hpke/src/group.cpp Executable file

File diff suppressed because it is too large Load Diff

116
DPP/mlspp/lib/hpke/src/group.h Executable file
View File

@@ -0,0 +1,116 @@
#pragma once
#include <hpke/hpke.h>
#include <hpke/signature.h>
#include "openssl_common.h"
#include <openssl/evp.h>
namespace mlspp::hpke {
struct Group
{
enum struct ID : uint8_t
{
P256,
P384,
P521,
X25519,
X448,
Ed25519,
Ed448,
};
struct PublicKey
: public KEM::PublicKey
, public Signature::PublicKey
{
virtual ~PublicKey() = default;
};
struct PrivateKey
{
virtual ~PrivateKey() = default;
virtual std::unique_ptr<PublicKey> public_key() const = 0;
};
template<Group::ID id>
static const Group& get();
virtual ~Group() = default;
const ID id;
const size_t dh_size;
const size_t pk_size;
const size_t sk_size;
const std::string jwk_key_type;
const std::string jwk_curve_name;
virtual std::unique_ptr<PrivateKey> generate_key_pair() const = 0;
virtual std::unique_ptr<PrivateKey> derive_key_pair(
const bytes& suite_id,
const bytes& ikm) const = 0;
virtual bytes serialize(const PublicKey& pk) const = 0;
virtual std::unique_ptr<PublicKey> deserialize(const bytes& enc) const = 0;
virtual bytes serialize_private(const PrivateKey& sk) const = 0;
virtual std::unique_ptr<PrivateKey> deserialize_private(
const bytes& skm) const = 0;
virtual bytes dh(const PrivateKey& sk, const PublicKey& pk) const = 0;
virtual bytes sign(const bytes& data, const PrivateKey& sk) const = 0;
virtual bool verify(const bytes& data,
const bytes& sig,
const PublicKey& pk) const = 0;
virtual std::tuple<bytes, bytes> coordinates(const PublicKey& pk) const = 0;
virtual std::unique_ptr<PublicKey> public_key_from_coordinates(
const bytes& x,
const bytes& y) const = 0;
protected:
const KDF& kdf;
friend struct DHKEM;
Group(ID group_id_in, const KDF& kdf_in);
};
struct EVPGroup : public Group
{
EVPGroup(Group::ID group_id, const KDF& kdf);
struct PublicKey : public Group::PublicKey
{
explicit PublicKey(EVP_PKEY* pkey_in);
~PublicKey() override = default;
// NOLINTNEXTLINE(misc-non-private-member-variables-in-classes)
typed_unique_ptr<EVP_PKEY> pkey;
};
struct PrivateKey : public Group::PrivateKey
{
explicit PrivateKey(EVP_PKEY* pkey_in);
~PrivateKey() override = default;
std::unique_ptr<Group::PublicKey> public_key() const override;
// NOLINTNEXTLINE(misc-non-private-member-variables-in-classes)
typed_unique_ptr<EVP_PKEY> pkey;
};
std::unique_ptr<Group::PrivateKey> generate_key_pair() const override;
bytes dh(const Group::PrivateKey& sk,
const Group::PublicKey& pk) const override;
bytes sign(const bytes& data, const Group::PrivateKey& sk) const override;
bool verify(const bytes& data,
const bytes& sig,
const Group::PublicKey& pk) const override;
};
} // namespace mlspp::hpke

79
DPP/mlspp/lib/hpke/src/hkdf.cpp Executable file
View File

@@ -0,0 +1,79 @@
#include "hkdf.h"
#include "openssl_common.h"
#include <openssl/err.h>
#include <openssl/evp.h>
#include <stdexcept>
namespace mlspp::hpke {
template<>
const HKDF&
HKDF::get<Digest::ID::SHA256>()
{
static const HKDF instance(Digest::get<Digest::ID::SHA256>());
return instance;
}
template<>
const HKDF&
HKDF::get<Digest::ID::SHA384>()
{
static const HKDF instance(Digest::get<Digest::ID::SHA384>());
return instance;
}
template<>
const HKDF&
HKDF::get<Digest::ID::SHA512>()
{
static const HKDF instance(Digest::get<Digest::ID::SHA512>());
return instance;
}
static KDF::ID
digest_to_kdf(Digest::ID digest_id)
{
switch (digest_id) {
case Digest::ID::SHA256:
return KDF::ID::HKDF_SHA256;
case Digest::ID::SHA384:
return KDF::ID::HKDF_SHA384;
case Digest::ID::SHA512:
return KDF::ID::HKDF_SHA512;
}
throw std::runtime_error("Unsupported algorithm");
}
HKDF::HKDF(const Digest& digest_in)
: KDF(digest_to_kdf(digest_in.id), digest_in.hash_size)
, digest(digest_in)
{
}
bytes
HKDF::extract(const bytes& salt, const bytes& ikm) const
{
return digest.hmac_for_hkdf_extract(salt, ikm);
}
bytes
HKDF::expand(const bytes& prk, const bytes& info, size_t size) const
{
auto okm = bytes{};
auto i = uint8_t(0x00);
auto Ti = bytes{};
while (okm.size() < size) {
i += 1;
auto block = Ti + info + bytes{ i };
Ti = digest.hmac(prk, block);
okm += Ti;
}
okm.resize(size);
return okm;
}
} // namespace mlspp::hpke

24
DPP/mlspp/lib/hpke/src/hkdf.h Executable file
View File

@@ -0,0 +1,24 @@
#pragma once
#include <hpke/digest.h>
#include <hpke/hpke.h>
namespace mlspp::hpke {
struct HKDF : public KDF
{
template<Digest::ID digest_id>
static const HKDF& get();
~HKDF() override = default;
bytes extract(const bytes& salt, const bytes& ikm) const override;
bytes expand(const bytes& prk, const bytes& info, size_t size) const override;
private:
const Digest& digest;
explicit HKDF(const Digest& digest_in);
};
} // namespace mlspp::hpke

540
DPP/mlspp/lib/hpke/src/hpke.cpp Executable file
View File

@@ -0,0 +1,540 @@
#include <hpke/digest.h>
#include <hpke/hpke.h>
#include "aead_cipher.h"
#include "common.h"
#include "dhkem.h"
#include "hkdf.h"
#include <limits>
#include <stdexcept>
#include <string>
namespace mlspp::hpke {
///
/// Helper functions and constants
///
static const bytes&
label_exp()
{
static const bytes val = from_ascii("exp");
return val;
}
static const bytes&
label_hpke()
{
static const bytes val = from_ascii("HPKE");
return val;
}
static const bytes&
label_hpke_version()
{
static const bytes val = from_ascii("HPKE-v1");
return val;
}
static const bytes&
label_info_hash()
{
static const bytes val = from_ascii("info_hash");
return val;
}
static const bytes&
label_key()
{
static const bytes val = from_ascii("key");
return val;
}
static const bytes&
label_base_nonce()
{
static const bytes val = from_ascii("base_nonce");
return val;
}
static const bytes&
label_psk_id_hash()
{
static const bytes val = from_ascii("psk_id_hash");
return val;
}
static const bytes&
label_sec()
{
static const bytes val = from_ascii("sec");
return val;
}
static const bytes&
label_secret()
{
static const bytes val = from_ascii("secret");
return val;
}
///
/// Factory methods for primitives
///
KEM::KEM(ID id_in,
size_t secret_size_in,
size_t enc_size_in,
size_t pk_size_in,
size_t sk_size_in)
: id(id_in)
, secret_size(secret_size_in)
, enc_size(enc_size_in)
, pk_size(pk_size_in)
, sk_size(sk_size_in)
{
}
template<>
const KEM&
KEM::get<KEM::ID::DHKEM_P256_SHA256>()
{
return DHKEM::get<KEM::ID::DHKEM_P256_SHA256>();
}
template<>
const KEM&
KEM::get<KEM::ID::DHKEM_P384_SHA384>()
{
return DHKEM::get<KEM::ID::DHKEM_P384_SHA384>();
}
template<>
const KEM&
KEM::get<KEM::ID::DHKEM_P521_SHA512>()
{
return DHKEM::get<KEM::ID::DHKEM_P521_SHA512>();
}
template<>
const KEM&
KEM::get<KEM::ID::DHKEM_X25519_SHA256>()
{
return DHKEM::get<KEM::ID::DHKEM_X25519_SHA256>();
}
#if !defined(WITH_BORINGSSL)
template<>
const KEM&
KEM::get<KEM::ID::DHKEM_X448_SHA512>()
{
return DHKEM::get<KEM::ID::DHKEM_X448_SHA512>();
}
#endif
bytes
KEM::serialize_private(const KEM::PrivateKey& /* unused */) const
{
throw std::runtime_error("Not implemented");
}
std::unique_ptr<KEM::PrivateKey>
KEM::deserialize_private(const bytes& /* unused */) const
{
throw std::runtime_error("Not implemented");
}
std::pair<bytes, bytes>
KEM::auth_encap(const PublicKey& /* unused */,
const PrivateKey& /* unused */) const
{
throw std::runtime_error("Not implemented");
}
bytes
KEM::auth_decap(const bytes& /* unused */,
const PublicKey& /* unused */,
const PrivateKey& /* unused */) const
{
throw std::runtime_error("Not implemented");
}
template<>
const KDF&
KDF::get<KDF::ID::HKDF_SHA256>()
{
return HKDF::get<Digest::ID::SHA256>();
}
template<>
const KDF&
KDF::get<KDF::ID::HKDF_SHA384>()
{
return HKDF::get<Digest::ID::SHA384>();
}
template<>
const KDF&
KDF::get<KDF::ID::HKDF_SHA512>()
{
return HKDF::get<Digest::ID::SHA512>();
}
KDF::KDF(ID id_in, size_t hash_size_in)
: id(id_in)
, hash_size(hash_size_in)
{
}
bytes
KDF::labeled_extract(const bytes& suite_id,
const bytes& salt,
const bytes& label,
const bytes& ikm) const
{
auto labeled_ikm = label_hpke_version() + suite_id + label + ikm;
return extract(salt, labeled_ikm);
}
bytes
KDF::labeled_expand(const bytes& suite_id,
const bytes& prk,
const bytes& label,
const bytes& info,
size_t size) const
{
auto labeled_info =
i2osp(size, 2) + label_hpke_version() + suite_id + label + info;
return expand(prk, labeled_info, size);
}
template<>
const AEAD&
AEAD::get<AEAD::ID::AES_128_GCM>()
{
return AEADCipher::get<AEAD::ID::AES_128_GCM>();
}
template<>
const AEAD&
AEAD::get<AEAD::ID::AES_256_GCM>()
{
return AEADCipher::get<AEAD::ID::AES_256_GCM>();
}
template<>
const AEAD&
AEAD::get<AEAD::ID::CHACHA20_POLY1305>()
{
return AEADCipher::get<AEAD::ID::CHACHA20_POLY1305>();
}
template<>
const AEAD&
AEAD::get<AEAD::ID::export_only>()
{
static const auto export_only = ExportOnlyCipher{};
return export_only;
}
AEAD::AEAD(ID id_in, size_t key_size_in, size_t nonce_size_in)
: id(id_in)
, key_size(key_size_in)
, nonce_size(nonce_size_in)
{
}
///
/// Encryption Contexts
///
bytes
Context::do_export(const bytes& exporter_context, size_t size) const
{
return kdf.labeled_expand(
suite, exporter_secret, label_sec(), exporter_context, size);
}
bytes
Context::current_nonce() const
{
auto curr = i2osp(seq, aead.nonce_size);
return curr ^ nonce;
}
void
Context::increment_seq()
{
if (seq == std::numeric_limits<uint64_t>::max()) {
throw std::runtime_error("Sequence number overflow");
}
seq += 1;
}
Context::Context(bytes suite_in,
bytes key_in,
bytes nonce_in,
bytes exporter_secret_in,
const KDF& kdf_in,
const AEAD& aead_in)
: suite(std::move(suite_in))
, key(std::move(key_in))
, nonce(std::move(nonce_in))
, exporter_secret(std::move(exporter_secret_in))
, kdf(kdf_in)
, aead(aead_in)
, seq(0)
{
}
bool
operator==(const Context& lhs, const Context& rhs)
{
// TODO(RLB) Compare KDF and AEAD algorithms
auto suite = (lhs.suite == rhs.suite);
auto key = (lhs.key == rhs.key);
auto nonce = (lhs.nonce == rhs.nonce);
auto exporter_secret = (lhs.exporter_secret == rhs.exporter_secret);
auto seq = (lhs.seq == rhs.seq);
return suite && key && nonce && exporter_secret && seq;
}
SenderContext::SenderContext(Context&& c)
: Context(std::move(c))
{
}
bytes
SenderContext::seal(const bytes& aad, const bytes& pt)
{
auto ct = aead.seal(key, current_nonce(), aad, pt);
increment_seq();
return ct;
}
ReceiverContext::ReceiverContext(Context&& c)
: Context(std::move(c))
{
}
std::optional<bytes>
ReceiverContext::open(const bytes& aad, const bytes& ct)
{
auto maybe_pt = aead.open(key, current_nonce(), aad, ct);
increment_seq();
return maybe_pt;
}
///
/// HPKE
///
static const bytes default_psk = {};
static const bytes default_psk_id = {};
static bytes
suite_id(KEM::ID kem_id, KDF::ID kdf_id, AEAD::ID aead_id)
{
return label_hpke() + i2osp(static_cast<uint64_t>(kem_id), 2) +
i2osp(static_cast<uint64_t>(kdf_id), 2) +
i2osp(static_cast<uint64_t>(aead_id), 2);
}
static const KEM&
select_kem(KEM::ID id)
{
switch (id) {
case KEM::ID::DHKEM_P256_SHA256:
return KEM::get<KEM::ID::DHKEM_P256_SHA256>();
case KEM::ID::DHKEM_P384_SHA384:
return KEM::get<KEM::ID::DHKEM_P384_SHA384>();
case KEM::ID::DHKEM_P521_SHA512:
return KEM::get<KEM::ID::DHKEM_P521_SHA512>();
case KEM::ID::DHKEM_X25519_SHA256:
return KEM::get<KEM::ID::DHKEM_X25519_SHA256>();
#if !defined(WITH_BORINGSSL)
case KEM::ID::DHKEM_X448_SHA512:
return KEM::get<KEM::ID::DHKEM_X448_SHA512>();
#endif
default:
throw std::runtime_error("Unsupported algorithm");
}
}
static const KDF&
select_kdf(KDF::ID id)
{
switch (id) {
case KDF::ID::HKDF_SHA256:
return KDF::get<KDF::ID::HKDF_SHA256>();
case KDF::ID::HKDF_SHA384:
return KDF::get<KDF::ID::HKDF_SHA384>();
case KDF::ID::HKDF_SHA512:
return KDF::get<KDF::ID::HKDF_SHA512>();
default:
throw std::runtime_error("Unsupported algorithm");
}
}
static const AEAD&
select_aead(AEAD::ID id)
{
switch (id) {
case AEAD::ID::AES_128_GCM:
return AEAD::get<AEAD::ID::AES_128_GCM>();
case AEAD::ID::AES_256_GCM:
return AEAD::get<AEAD::ID::AES_256_GCM>();
case AEAD::ID::CHACHA20_POLY1305:
return AEAD::get<AEAD::ID::CHACHA20_POLY1305>();
case AEAD::ID::export_only:
return AEAD::get<AEAD::ID::export_only>();
default:
throw std::runtime_error("Unsupported algorithm");
}
}
HPKE::HPKE(KEM::ID kem_id, KDF::ID kdf_id, AEAD::ID aead_id)
: suite(suite_id(kem_id, kdf_id, aead_id))
, kem(select_kem(kem_id))
, kdf(select_kdf(kdf_id))
, aead(select_aead(aead_id))
{
}
HPKE::SenderInfo
HPKE::setup_base_s(const KEM::PublicKey& pkR, const bytes& info) const
{
auto [shared_secret, enc] = kem.encap(pkR);
auto ctx =
key_schedule(Mode::base, shared_secret, info, default_psk, default_psk_id);
return std::make_pair(enc, SenderContext(std::move(ctx)));
}
ReceiverContext
HPKE::setup_base_r(const bytes& enc,
const KEM::PrivateKey& skR,
const bytes& info) const
{
auto pkRm = kem.serialize(*skR.public_key());
auto shared_secret = kem.decap(enc, skR);
auto ctx =
key_schedule(Mode::base, shared_secret, info, default_psk, default_psk_id);
return { std::move(ctx) };
}
HPKE::SenderInfo
HPKE::setup_psk_s(const KEM::PublicKey& pkR,
const bytes& info,
const bytes& psk,
const bytes& psk_id) const
{
auto [shared_secret, enc] = kem.encap(pkR);
auto ctx = key_schedule(Mode::psk, shared_secret, info, psk, psk_id);
return std::make_pair(enc, SenderContext(std::move(ctx)));
}
ReceiverContext
HPKE::setup_psk_r(const bytes& enc,
const KEM::PrivateKey& skR,
const bytes& info,
const bytes& psk,
const bytes& psk_id) const
{
auto shared_secret = kem.decap(enc, skR);
auto ctx = key_schedule(Mode::psk, shared_secret, info, psk, psk_id);
return { std::move(ctx) };
}
HPKE::SenderInfo
HPKE::setup_auth_s(const KEM::PublicKey& pkR,
const bytes& info,
const KEM::PrivateKey& skS) const
{
auto [shared_secret, enc] = kem.auth_encap(pkR, skS);
auto ctx =
key_schedule(Mode::auth, shared_secret, info, default_psk, default_psk_id);
return std::make_pair(enc, SenderContext(std::move(ctx)));
}
ReceiverContext
HPKE::setup_auth_r(const bytes& enc,
const KEM::PrivateKey& skR,
const bytes& info,
const KEM::PublicKey& pkS) const
{
auto shared_secret = kem.auth_decap(enc, pkS, skR);
auto ctx =
key_schedule(Mode::auth, shared_secret, info, default_psk, default_psk_id);
return { std::move(ctx) };
}
HPKE::SenderInfo
HPKE::setup_auth_psk_s(const KEM::PublicKey& pkR,
const bytes& info,
const bytes& psk,
const bytes& psk_id,
const KEM::PrivateKey& skS) const
{
auto [shared_secret, enc] = kem.auth_encap(pkR, skS);
auto ctx = key_schedule(Mode::auth_psk, shared_secret, info, psk, psk_id);
return std::make_pair(enc, SenderContext(std::move(ctx)));
}
ReceiverContext
HPKE::setup_auth_psk_r(const bytes& enc,
const KEM::PrivateKey& skR,
const bytes& info,
const bytes& psk,
const bytes& psk_id,
const KEM::PublicKey& pkS) const
{
auto shared_secret = kem.auth_decap(enc, pkS, skR);
auto ctx = key_schedule(Mode::auth_psk, shared_secret, info, psk, psk_id);
return { std::move(ctx) };
}
bool
HPKE::verify_psk_inputs(Mode mode, const bytes& psk, const bytes& psk_id)
{
auto got_psk = (psk != default_psk);
auto got_psk_id = (psk_id != default_psk_id);
if (got_psk != got_psk_id) {
return false;
}
return (!got_psk && (mode == Mode::base || mode == Mode::auth)) ||
(got_psk && (mode == Mode::psk || mode == Mode::auth_psk));
}
Context
HPKE::key_schedule(Mode mode,
const bytes& shared_secret,
const bytes& info,
const bytes& psk,
const bytes& psk_id) const
{
if (!verify_psk_inputs(mode, psk, psk_id)) {
throw std::runtime_error("Invalid PSK inputs");
}
auto psk_id_hash =
kdf.labeled_extract(suite, {}, label_psk_id_hash(), psk_id);
auto info_hash = kdf.labeled_extract(suite, {}, label_info_hash(), info);
auto mode_bytes = bytes{ uint8_t(mode) };
auto key_schedule_context = mode_bytes + psk_id_hash + info_hash;
auto secret = kdf.labeled_extract(suite, shared_secret, label_secret(), psk);
auto key = kdf.labeled_expand(
suite, secret, label_key(), key_schedule_context, aead.key_size);
auto nonce = kdf.labeled_expand(
suite, secret, label_base_nonce(), key_schedule_context, aead.nonce_size);
auto exporter_secret = kdf.labeled_expand(
suite, secret, label_exp(), key_schedule_context, kdf.hash_size);
return { suite, key, nonce, exporter_secret, kdf, aead };
}
} // namespace mlspp::hpke

View File

@@ -0,0 +1,160 @@
#include "openssl_common.h"
#include <openssl/ec.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/hmac.h>
#include <openssl/x509.h>
#include <openssl/x509v3.h>
#if defined(WITH_OPENSSL3)
#include <openssl/param_build.h>
#endif
namespace mlspp::hpke {
template<>
void
typed_delete(EVP_CIPHER_CTX* ptr)
{
EVP_CIPHER_CTX_free(ptr);
}
#if WITH_BORINGSSL
template<>
void
typed_delete(EVP_AEAD_CTX* ptr)
{
EVP_AEAD_CTX_free(ptr);
}
#endif
template<>
void
typed_delete(EVP_PKEY_CTX* ptr)
{
EVP_PKEY_CTX_free(ptr);
}
template<>
void
typed_delete(EVP_MD_CTX* ptr)
{
EVP_MD_CTX_free(ptr);
}
#if !defined(WITH_OPENSSL3)
template<>
void
typed_delete(HMAC_CTX* ptr)
{
HMAC_CTX_free(ptr);
}
#endif
template<>
void
typed_delete(EVP_PKEY* ptr)
{
EVP_PKEY_free(ptr);
}
template<>
void
typed_delete(BIGNUM* ptr)
{
BN_free(ptr);
}
template<>
void
typed_delete(EC_POINT* ptr)
{
EC_POINT_free(ptr);
}
#if !defined(WITH_OPENSSL3)
template<>
void
typed_delete(EC_KEY* ptr)
{
EC_KEY_free(ptr);
}
#endif
#if defined(WITH_OPENSSL3)
template<>
void
typed_delete(EVP_MAC* ptr)
{
EVP_MAC_free(ptr);
}
template<>
void
typed_delete(EVP_MAC_CTX* ptr)
{
EVP_MAC_CTX_free(ptr);
}
template<>
void
typed_delete(EC_GROUP* ptr)
{
EC_GROUP_free(ptr);
}
template<>
void
typed_delete(OSSL_PARAM_BLD* ptr)
{
OSSL_PARAM_BLD_free(ptr);
}
template<>
void
typed_delete(OSSL_PARAM* ptr)
{
OSSL_PARAM_free(ptr);
}
#endif
template<>
void
typed_delete(X509* ptr)
{
X509_free(ptr);
}
template<>
void
typed_delete(STACK_OF(GENERAL_NAME) * ptr)
{
sk_GENERAL_NAME_pop_free(ptr, GENERAL_NAME_free);
}
template<>
void
typed_delete(BIO* ptr)
{
BIO_vfree(ptr);
}
template<>
void
typed_delete(ASN1_TIME* ptr)
{
ASN1_TIME_free(ptr);
}
///
/// Map OpenSSL errors to C++ exceptions
///
std::runtime_error
openssl_error()
{
auto code = ERR_get_error();
return std::runtime_error(ERR_error_string(code, nullptr));
}
} // namespace mlspp::hpke

View File

@@ -0,0 +1,25 @@
#pragma once
#include <hpke/hpke.h>
#include <stdexcept>
namespace mlspp::hpke {
template<typename T>
void
typed_delete(T* ptr);
template<typename T>
using typed_unique_ptr = std::unique_ptr<T, decltype(&typed_delete<T>)>;
template<typename T>
typed_unique_ptr<T>
make_typed_unique(T* ptr)
{
return typed_unique_ptr<T>(ptr, typed_delete<T>);
}
std::runtime_error
openssl_error();
} // namespace mlspp::hpke

View File

@@ -0,0 +1,19 @@
#include <hpke/random.h>
#include "openssl_common.h"
#include <openssl/rand.h>
namespace mlspp::hpke {
bytes
random_bytes(size_t size)
{
auto rand = bytes(size);
if (1 != RAND_bytes(rand.data(), static_cast<int>(size))) {
throw openssl_error();
}
return rand;
}
} // namespace mlspp::hpke

207
DPP/mlspp/lib/hpke/src/rsa.cpp Executable file
View File

@@ -0,0 +1,207 @@
#include "rsa.h"
#include "common.h"
#include "openssl/rsa.h"
#include "openssl_common.h"
namespace mlspp::hpke {
std::unique_ptr<Signature::PrivateKey>
RSASignature::generate_key_pair() const
{
throw std::runtime_error("Not implemented");
}
std::unique_ptr<Signature::PrivateKey>
RSASignature::derive_key_pair(const bytes& /*ikm*/) const
{
throw std::runtime_error("Not implemented");
}
std::unique_ptr<Signature::PrivateKey>
RSASignature::generate_key_pair(size_t bits)
{
auto ctx = make_typed_unique(EVP_PKEY_CTX_new_id(EVP_PKEY_RSA, nullptr));
if (ctx == nullptr) {
throw openssl_error();
}
if (EVP_PKEY_keygen_init(ctx.get()) <= 0) {
throw openssl_error();
}
// NOLINTNEXTLINE(hicpp-signed-bitwise)
if (EVP_PKEY_CTX_set_rsa_keygen_bits(ctx.get(), static_cast<int>(bits)) <=
0) {
throw openssl_error();
}
auto* pkey = static_cast<EVP_PKEY*>(nullptr);
if (EVP_PKEY_keygen(ctx.get(), &pkey) <= 0) {
throw openssl_error();
}
return std::make_unique<PrivateKey>(pkey);
}
// TODO(rlb): Implement derive() with sizes
bytes
RSASignature::serialize(const Signature::PublicKey& pk) const
{
const auto& rpk = dynamic_cast<const PublicKey&>(pk);
const int len = i2d_PublicKey(rpk.pkey.get(), nullptr);
auto raw = bytes(len);
auto* data_ptr = raw.data();
if (len != i2d_PublicKey(rpk.pkey.get(), &data_ptr)) {
throw openssl_error();
}
return raw;
}
std::unique_ptr<Signature::PublicKey>
RSASignature::deserialize(const bytes& enc) const
{
const auto* data_ptr = enc.data();
auto* pkey = d2i_PublicKey(
EVP_PKEY_RSA, nullptr, &data_ptr, static_cast<int>(enc.size()));
if (pkey == nullptr) {
throw openssl_error();
}
return std::make_unique<RSASignature::PublicKey>(pkey);
}
bytes
RSASignature::serialize_private(const Signature::PrivateKey& sk) const
{
const auto& rsk = dynamic_cast<const PrivateKey&>(sk);
const int len = i2d_PrivateKey(rsk.pkey.get(), nullptr);
auto raw = bytes(len);
auto* data_ptr = raw.data();
if (len != i2d_PrivateKey(rsk.pkey.get(), &data_ptr)) {
throw openssl_error();
}
return raw;
}
std::unique_ptr<Signature::PrivateKey>
RSASignature::deserialize_private(const bytes& skm) const
{
const auto* data_ptr = skm.data();
auto* pkey = d2i_PrivateKey(
EVP_PKEY_RSA, nullptr, &data_ptr, static_cast<int>(skm.size()));
if (pkey == nullptr) {
throw openssl_error();
}
return std::make_unique<RSASignature::PrivateKey>(pkey);
}
bytes
RSASignature::sign(const bytes& data, const Signature::PrivateKey& sk) const
{
const auto& rsk = dynamic_cast<const PrivateKey&>(sk);
auto ctx = make_typed_unique(EVP_MD_CTX_create());
if (ctx == nullptr) {
throw openssl_error();
}
if (1 !=
EVP_DigestSignInit(ctx.get(), nullptr, md, nullptr, rsk.pkey.get())) {
throw openssl_error();
}
size_t siglen = EVP_PKEY_size(rsk.pkey.get());
bytes sig(siglen);
if (1 != EVP_DigestSign(
ctx.get(), sig.data(), &siglen, data.data(), data.size())) {
throw openssl_error();
}
sig.resize(siglen);
return sig;
}
bool
RSASignature::verify(const bytes& data,
const bytes& sig,
const Signature::PublicKey& pk) const
{
const auto& rpk = dynamic_cast<const PublicKey&>(pk);
auto ctx = make_typed_unique(EVP_MD_CTX_create());
if (ctx == nullptr) {
throw openssl_error();
}
if (1 !=
EVP_DigestVerifyInit(ctx.get(), nullptr, md, nullptr, rpk.pkey.get())) {
throw openssl_error();
}
auto rv = EVP_DigestVerify(
ctx.get(), sig.data(), sig.size(), data.data(), data.size());
return rv == 1;
}
// TODO(RLB) Implement these methods. No concrete need, but might be nice for
// completeness.
std::unique_ptr<Signature::PrivateKey>
RSASignature::import_jwk_private(const std::string& /* json_str */) const
{
throw std::runtime_error("not implemented");
}
std::unique_ptr<Signature::PublicKey>
RSASignature::import_jwk(const std::string& /* json_str */) const
{
throw std::runtime_error("not implemented");
}
std::string
RSASignature::export_jwk_private(const Signature::PrivateKey& /* sk */) const
{
throw std::runtime_error("not implemented");
}
std::string
RSASignature::export_jwk(const Signature::PublicKey& /* pk */) const
{
throw std::runtime_error("not implemented");
}
const EVP_MD*
RSASignature::digest_to_md(Digest::ID digest)
{
// NOLINTNEXTLINE(hicpp-multiway-paths-covered)
switch (digest) {
case Digest::ID::SHA256:
return EVP_sha256();
case Digest::ID::SHA384:
return EVP_sha384();
case Digest::ID::SHA512:
return EVP_sha512();
default:
throw std::runtime_error("Unsupported digest");
}
}
Signature::ID
RSASignature::digest_to_sig(Digest::ID digest)
{
// NOLINTNEXTLINE(hicpp-multiway-paths-covered)
switch (digest) {
case Digest::ID::SHA256:
return Signature::ID::RSA_SHA256;
case Digest::ID::SHA384:
return Signature::ID::RSA_SHA384;
case Digest::ID::SHA512:
return Signature::ID::RSA_SHA512;
default:
throw std::runtime_error("Unsupported digest");
}
}
} // namespace mlspp::hpke

97
DPP/mlspp/lib/hpke/src/rsa.h Executable file
View File

@@ -0,0 +1,97 @@
#pragma once
#include <hpke/digest.h>
#include <hpke/hpke.h>
#include <hpke/signature.h>
#include "openssl_common.h"
#include <openssl/evp.h>
#include <openssl/rsa.h>
namespace mlspp::hpke {
// XXX(RLB): There is a lot of code in RSASignature that is duplicated in
// EVPGroup. I have allowed this duplication rather than factoring it out
// because I would like to be able to cleanly remove RSA later.
struct RSASignature : public Signature
{
struct PublicKey : public Signature::PublicKey
{
explicit PublicKey(EVP_PKEY* pkey_in)
: pkey(pkey_in, typed_delete<EVP_PKEY>)
{
}
~PublicKey() override = default;
typed_unique_ptr<EVP_PKEY> pkey;
};
struct PrivateKey : public Signature::PrivateKey
{
explicit PrivateKey(EVP_PKEY* pkey_in)
: pkey(pkey_in, typed_delete<EVP_PKEY>)
{
}
~PrivateKey() override = default;
std::unique_ptr<Signature::PublicKey> public_key() const override
{
if (1 != EVP_PKEY_up_ref(pkey.get())) {
throw openssl_error();
}
return std::make_unique<PublicKey>(pkey.get());
}
typed_unique_ptr<EVP_PKEY> pkey;
};
explicit RSASignature(Digest::ID digest)
: Signature(digest_to_sig(digest))
, md(digest_to_md(digest))
{
}
std::unique_ptr<Signature::PrivateKey> generate_key_pair() const override;
std::unique_ptr<Signature::PrivateKey> derive_key_pair(
const bytes& /*ikm*/) const override;
static std::unique_ptr<Signature::PrivateKey> generate_key_pair(size_t bits);
// TODO(rlb): Implement derive() with sizes
bytes serialize(const Signature::PublicKey& pk) const override;
std::unique_ptr<Signature::PublicKey> deserialize(
const bytes& enc) const override;
bytes serialize_private(const Signature::PrivateKey& sk) const override;
std::unique_ptr<Signature::PrivateKey> deserialize_private(
const bytes& skm) const override;
bytes sign(const bytes& data, const Signature::PrivateKey& sk) const override;
bool verify(const bytes& data,
const bytes& sig,
const Signature::PublicKey& pk) const override;
std::unique_ptr<Signature::PrivateKey> import_jwk_private(
const std::string& json_str) const override;
std::unique_ptr<Signature::PublicKey> import_jwk(
const std::string& json_str) const override;
std::string export_jwk_private(
const Signature::PrivateKey& sk) const override;
std::string export_jwk(const Signature::PublicKey& pk) const override;
private:
const EVP_MD* md;
static const EVP_MD* digest_to_md(Digest::ID digest);
static Signature::ID digest_to_sig(Digest::ID digest);
};
} // namespace mlspp::hpke

View File

@@ -0,0 +1,344 @@
#include <hpke/base64.h>
#include <hpke/digest.h>
#include <hpke/signature.h>
#include <string>
#include "dhkem.h"
#include "rsa.h"
#include <dpp/json.h>
#include <openssl/bio.h>
#include <openssl/bn.h>
#include <openssl/ec.h>
#include <openssl/evp.h>
#include <openssl/rsa.h>
using nlohmann::json;
namespace mlspp::hpke {
struct GroupSignature : public Signature
{
struct PrivateKey : public Signature::PrivateKey
{
explicit PrivateKey(Group::PrivateKey* group_priv_in)
: group_priv(group_priv_in)
{
}
std::unique_ptr<Signature::PublicKey> public_key() const override
{
return group_priv->public_key();
}
std::unique_ptr<Group::PrivateKey> group_priv;
};
static Signature::ID group_to_sig(Group::ID group_id)
{
switch (group_id) {
case Group::ID::P256:
return Signature::ID::P256_SHA256;
case Group::ID::P384:
return Signature::ID::P384_SHA384;
case Group::ID::P521:
return Signature::ID::P521_SHA512;
case Group::ID::Ed25519:
return Signature::ID::Ed25519;
#if !defined(WITH_BORINGSSL)
case Group::ID::Ed448:
return Signature::ID::Ed448;
#endif
default:
throw std::runtime_error("Unsupported group");
}
}
explicit GroupSignature(const Group& group_in)
: Signature(group_to_sig(group_in.id))
, group(group_in)
{
}
std::unique_ptr<Signature::PrivateKey> generate_key_pair() const override
{
return std::make_unique<PrivateKey>(group.generate_key_pair().release());
}
std::unique_ptr<Signature::PrivateKey> derive_key_pair(
const bytes& ikm) const override
{
return std::make_unique<PrivateKey>(
group.derive_key_pair({}, ikm).release());
}
bytes serialize(const Signature::PublicKey& pk) const override
{
const auto& rpk = dynamic_cast<const Group::PublicKey&>(pk);
return group.serialize(rpk);
}
std::unique_ptr<Signature::PublicKey> deserialize(
const bytes& enc) const override
{
return group.deserialize(enc);
}
bytes serialize_private(const Signature::PrivateKey& sk) const override
{
const auto& rsk = dynamic_cast<const PrivateKey&>(sk);
return group.serialize_private(*rsk.group_priv);
}
std::unique_ptr<Signature::PrivateKey> deserialize_private(
const bytes& skm) const override
{
return std::make_unique<PrivateKey>(
group.deserialize_private(skm).release());
}
bytes sign(const bytes& data, const Signature::PrivateKey& sk) const override
{
const auto& rsk = dynamic_cast<const PrivateKey&>(sk);
return group.sign(data, *rsk.group_priv);
}
bool verify(const bytes& data,
const bytes& sig,
const Signature::PublicKey& pk) const override
{
const auto& rpk = dynamic_cast<const Group::PublicKey&>(pk);
return group.verify(data, sig, rpk);
}
std::unique_ptr<Signature::PrivateKey> import_jwk_private(
const std::string& jwk_json) const override
{
const auto jwk = validate_jwk_json(jwk_json, true);
const auto d = from_base64url(jwk.at("d"));
auto gsk = group.deserialize_private(d);
return std::make_unique<PrivateKey>(gsk.release());
}
std::unique_ptr<Signature::PublicKey> import_jwk(
const std::string& jwk_json) const override
{
const auto jwk = validate_jwk_json(jwk_json, false);
const auto x = from_base64url(jwk.at("x"));
auto y = bytes{};
if (jwk.contains("y")) {
y = from_base64url(jwk.at("y"));
}
return group.public_key_from_coordinates(x, y);
}
std::string export_jwk(const Signature::PublicKey& pk) const override
{
const auto& gpk = dynamic_cast<const Group::PublicKey&>(pk);
const auto jwk_json = export_jwk_json(gpk);
return jwk_json.dump();
}
std::string export_jwk_private(const Signature::PrivateKey& sk) const override
{
const auto& gssk = dynamic_cast<const GroupSignature::PrivateKey&>(sk);
const auto& gsk = gssk.group_priv;
const auto gpk = gsk->public_key();
auto jwk_json = export_jwk_json(*gpk);
// encode the private key
const auto enc = serialize_private(sk);
jwk_json.emplace("d", to_base64url(enc));
return jwk_json.dump();
}
private:
const Group& group;
json validate_jwk_json(const std::string& jwk_json, bool private_key) const
{
json jwk = json::parse(jwk_json);
if (jwk.empty() || !jwk.contains("kty") || !jwk.contains("crv") ||
!jwk.contains("x") || (private_key && !jwk.contains("d"))) {
throw std::runtime_error("malformed JWK");
}
if (jwk.at("kty") != group.jwk_key_type) {
throw std::runtime_error("invalid JWK key type");
}
if (jwk.at("crv") != group.jwk_curve_name) {
throw std::runtime_error("invalid JWK curve");
}
return jwk;
}
json export_jwk_json(const Group::PublicKey& pk) const
{
const auto [x, y] = group.coordinates(pk);
json jwk = json::object({
{ "crv", group.jwk_curve_name },
{ "kty", group.jwk_key_type },
});
if (group.jwk_key_type == "EC") {
jwk.emplace("x", to_base64url(x));
jwk.emplace("y", to_base64url(y));
} else if (group.jwk_key_type == "OKP") {
jwk.emplace("x", to_base64url(x));
} else {
throw std::runtime_error("unknown key type");
}
return jwk;
}
};
template<>
const Signature&
Signature::get<Signature::ID::P256_SHA256>()
{
static const auto instance = GroupSignature(Group::get<Group::ID::P256>());
return instance;
}
template<>
const Signature&
Signature::get<Signature::ID::P384_SHA384>()
{
static const auto instance = GroupSignature(Group::get<Group::ID::P384>());
return instance;
}
template<>
const Signature&
Signature::get<Signature::ID::P521_SHA512>()
{
static const auto instance = GroupSignature(Group::get<Group::ID::P521>());
return instance;
}
template<>
const Signature&
Signature::get<Signature::ID::Ed25519>()
{
static const auto instance = GroupSignature(Group::get<Group::ID::Ed25519>());
return instance;
}
#if !defined(WITH_BORINGSSL)
template<>
const Signature&
Signature::get<Signature::ID::Ed448>()
{
static const auto instance = GroupSignature(Group::get<Group::ID::Ed448>());
return instance;
}
#endif
template<>
const Signature&
Signature::get<Signature::ID::RSA_SHA256>()
{
static const auto instance = RSASignature(Digest::ID::SHA256);
return instance;
}
template<>
const Signature&
Signature::get<Signature::ID::RSA_SHA384>()
{
static const auto instance = RSASignature(Digest::ID::SHA384);
return instance;
}
template<>
const Signature&
Signature::get<Signature::ID::RSA_SHA512>()
{
static const auto instance = RSASignature(Digest::ID::SHA512);
return instance;
}
Signature::Signature(Signature::ID id_in)
: id(id_in)
{
}
std::unique_ptr<Signature::PrivateKey>
Signature::generate_rsa(size_t bits)
{
return RSASignature::generate_key_pair(bits);
}
static const Signature&
sig_from_jwk(const std::string& jwk_json)
{
using KeyTypeAndCurve = std::tuple<std::string, std::string>;
static const auto alg_sig_map = std::map<KeyTypeAndCurve, const Signature&>
{
{ { "EC", "P-256" }, Signature::get<Signature::ID::P256_SHA256>() },
{ { "EC", "P-384" }, Signature::get<Signature::ID::P384_SHA384>() },
{ { "EC", "P-512" }, Signature::get<Signature::ID::P521_SHA512>() },
{ { "OKP", "Ed25519" }, Signature::get<Signature::ID::Ed25519>() },
#if !defined(WITH_BORINGSSL)
{ { "OKP", "Ed448" }, Signature::get<Signature::ID::Ed448>() },
#endif
// TODO(RLB): RSA
};
const auto jwk = json::parse(jwk_json);
const auto& kty = jwk.at("kty");
auto crv = std::string("");
if (jwk.contains("crv")) {
crv = jwk.at("crv");
}
const auto key = KeyTypeAndCurve{ kty, crv };
return alg_sig_map.at(key);
}
Signature::PrivateJWK
Signature::parse_jwk_private(const std::string& jwk_json)
{
// XXX(RLB): This JSON-parses the JWK twice. I'm assuming that this is a less
// bad cost than changing the import_jwk method signature to take `json`.
const auto& sig = sig_from_jwk(jwk_json);
const auto jwk = json::parse(jwk_json);
auto priv = sig.import_jwk_private(jwk_json);
auto kid = std::optional<std::string>{};
if (jwk.contains("kid")) {
kid = jwk.at("kid").get<std::string>();
}
return { sig, kid, std::move(priv) };
}
Signature::PublicJWK
Signature::parse_jwk(const std::string& jwk_json)
{
// XXX(RLB): Same double-parsing comment as with `parse_jwk_private`
const auto& sig = sig_from_jwk(jwk_json);
const auto jwk = json::parse(jwk_json);
auto pub = sig.import_jwk(jwk_json);
auto kid = std::optional<std::string>{};
if (jwk.contains("kid")) {
kid = jwk.at("kid").get<std::string>();
}
return { sig, kid, std::move(pub) };
}
} // namespace mlspp::hpke

View File

@@ -0,0 +1,401 @@
#include <hpke/base64.h>
#include <hpke/signature.h>
#include <hpke/userinfo_vc.h>
#include <dpp/json.h>
#include <tls/compat.h>
using nlohmann::json;
namespace mlspp::hpke {
static const std::string name_attr = "name";
static const std::string sub_attr = "sub";
static const std::string given_name_attr = "given_name";
static const std::string family_name_attr = "family_name";
static const std::string middle_name_attr = "middle_name";
static const std::string nickname_attr = "nickname";
static const std::string preferred_username_attr = "preferred_username";
static const std::string profile_attr = "profile";
static const std::string picture_attr = "picture";
static const std::string website_attr = "website";
static const std::string email_attr = "email";
static const std::string email_verified_attr = "email_verified";
static const std::string gender_attr = "gender";
static const std::string birthdate_attr = "birthdate";
static const std::string zoneinfo_attr = "zoneinfo";
static const std::string locale_attr = "locale";
static const std::string phone_number_attr = "phone_number";
static const std::string phone_number_verified_attr = "phone_number_verified";
static const std::string address_attr = "address";
static const std::string address_formatted_attr = "formatted";
static const std::string address_street_address_attr = "street_address";
static const std::string address_locality_attr = "locality";
static const std::string address_region_attr = "region";
static const std::string address_postal_code_attr = "postal_code";
static const std::string address_country_attr = "country";
static const std::string updated_at_attr = "updated_at";
template<typename T>
static std::optional<T>
get_optional(const json& json_object, const std::string& field_name)
{
if (!json_object.contains(field_name)) {
return std::nullopt;
}
return { json_object.at(field_name).get<T>() };
}
///
/// ParsedCredential
///
static const Signature&
signature_from_alg(const std::string& alg)
{
static const auto alg_sig_map = std::map<std::string, const Signature&>
{
{ "ES256", Signature::get<Signature::ID::P256_SHA256>() },
{ "ES384", Signature::get<Signature::ID::P384_SHA384>() },
{ "ES512", Signature::get<Signature::ID::P521_SHA512>() },
{ "Ed25519", Signature::get<Signature::ID::Ed25519>() },
#if !defined(WITH_BORINGSSL)
{ "Ed448", Signature::get<Signature::ID::Ed448>() },
#endif
{ "RS256", Signature::get<Signature::ID::RSA_SHA256>() },
{ "RS384", Signature::get<Signature::ID::RSA_SHA384>() },
{ "RS512", Signature::get<Signature::ID::RSA_SHA512>() },
};
return alg_sig_map.at(alg);
}
static std::chrono::system_clock::time_point
epoch_time(int64_t seconds_since_epoch)
{
const auto delta = std::chrono::seconds(seconds_since_epoch);
return std::chrono::system_clock::time_point(delta);
}
static bool
is_ecdsa(const Signature& sig)
{
return sig.id == Signature::ID::P256_SHA256 ||
sig.id == Signature::ID::P384_SHA384 ||
sig.id == Signature::ID::P521_SHA512;
}
// OpenSSL expects ECDSA signatures to be in DER form. JWS provides the
// signature in raw R||S form. So we need to do some manual DER encoding.
static bytes
jws_to_der_sig(const bytes& jws_sig)
{
// Inputs that are too large will result in invalid DER encodings with this
// code. At this size, the combination of the DER integer headers and the
// integer data will overflow the one-byte DER struct length.
static const auto max_sig_size = size_t(250);
if (jws_sig.size() > max_sig_size) {
throw std::runtime_error("JWS signature too large");
}
if (jws_sig.size() % 2 != 0) {
throw std::runtime_error("Malformed JWS signature");
}
const auto int_size = jws_sig.size() / 2;
const auto jws_sig_cut =
jws_sig.begin() + static_cast<std::ptrdiff_t>(int_size);
// Compute the encoded size of R and S integer data, adding a zero byte if
// needed to clear the sign bit
const auto r_big = (jws_sig.at(0) >= 0x80);
const auto s_big = (jws_sig.at(int_size) >= 0x80);
const auto r_size = int_size + (r_big ? 1 : 0);
const auto s_size = int_size + (s_big ? 1 : 0);
// Compute the size of the DER-encoded signature
static const auto int_header_size = 2;
const auto r_int_size = int_header_size + r_size;
const auto s_int_size = int_header_size + s_size;
const auto content_size = r_int_size + s_int_size;
const auto content_big = (content_size > 0x80);
auto der_header_size = 2 + (content_big ? 1 : 0);
const auto der_size = der_header_size + content_size;
// Allocate the DER buffer
auto der = bytes(der_size, 0);
// Write the header
der.at(0) = 0x30;
if (content_big) {
der.at(1) = 0x81;
der.at(2) = static_cast<uint8_t>(content_size);
} else {
der.at(1) = static_cast<uint8_t>(content_size);
}
// Write R, virtually padding with a zero byte if needed
const auto r_start = der_header_size;
const auto r_data_start = r_start + int_header_size + (r_big ? 1 : 0);
const auto r_data_begin =
der.begin() + static_cast<std::ptrdiff_t>(r_data_start);
der.at(r_start) = 0x02;
der.at(r_start + 1) = static_cast<uint8_t>(r_size);
std::copy(jws_sig.begin(), jws_sig_cut, r_data_begin);
// Write S, virtually padding with a zero byte if needed
const auto s_start = der_header_size + r_int_size;
const auto s_data_start = s_start + int_header_size + (s_big ? 1 : 0);
const auto s_data_begin =
der.begin() + static_cast<std::ptrdiff_t>(s_data_start);
der.at(s_start) = 0x02;
der.at(s_start + 1) = static_cast<uint8_t>(s_size);
std::copy(jws_sig_cut, jws_sig.end(), s_data_begin);
return der;
}
struct UserInfoVC::ParsedCredential
{
// Header fields
const Signature& signature_algorithm; // `alg`
std::optional<std::string> key_id; // `kid`
// Top-level Payload fields
std::string issuer; // `iss`
std::chrono::system_clock::time_point not_before; // `nbf`
std::chrono::system_clock::time_point not_after; // `exp`
// Credential subject fields
UserInfoClaims credential_subject;
Signature::PublicJWK public_key;
// Signature verification information
bytes to_be_signed;
bytes signature;
ParsedCredential(const Signature& signature_algorithm_in,
std::optional<std::string> key_id_in,
std::string issuer_in,
std::chrono::system_clock::time_point not_before_in,
std::chrono::system_clock::time_point not_after_in,
UserInfoClaims credential_subject_in,
Signature::PublicJWK&& public_key_in,
bytes to_be_signed_in,
bytes signature_in)
: signature_algorithm(signature_algorithm_in)
, key_id(std::move(key_id_in))
, issuer(std::move(issuer_in))
, not_before(not_before_in)
, not_after(not_after_in)
, credential_subject(std::move(credential_subject_in))
, public_key(std::move(public_key_in))
, to_be_signed(std::move(to_be_signed_in))
, signature(std::move(signature_in))
{
}
static std::shared_ptr<ParsedCredential> parse(const std::string& jwt)
{
// Split the JWT into its header, payload, and signature
const auto first_dot = jwt.find_first_of('.');
const auto last_dot = jwt.find_last_of('.');
if (first_dot == std::string::npos || last_dot == std::string::npos ||
first_dot == last_dot || last_dot > jwt.length() - 2) {
throw std::runtime_error("malformed JWT; not enough '.' characters");
}
const auto header_b64 = jwt.substr(0, first_dot);
const auto payload_b64 =
jwt.substr(first_dot + 1, last_dot - first_dot - 1);
const auto signature_b64 = jwt.substr(last_dot + 1);
// Parse the components
const auto header = json::parse(to_ascii(from_base64url(header_b64)));
const auto payload = json::parse(to_ascii(from_base64url(payload_b64)));
// Prepare the validation inputs
const auto hdr = header.at("alg");
const auto& sig = signature_from_alg(hdr);
const auto to_be_signed = from_ascii(header_b64 + "." + payload_b64);
auto signature = from_base64url(signature_b64);
if (is_ecdsa(sig)) {
signature = jws_to_der_sig(signature);
}
auto kid = std::optional<std::string>{};
if (header.contains("kid")) {
kid = header.at("kid").get<std::string>();
}
// Verify the VC parts
const auto& vc = payload.at("vc");
static const auto context =
std::vector<std::string>{ { "https://www.w3.org/2018/credentials/v1" } };
const auto vc_context = vc.at("@context").get<std::vector<std::string>>();
if (vc_context != context) {
throw std::runtime_error("malformed VC: incorrect context value");
}
static const auto type = std::vector<std::string>{
"VerifiableCredential",
"UserInfoCredential",
};
if (vc.at("type") != type) {
throw std::runtime_error("malformed VC: incorrect type value");
}
// Parse the subject public key
static const std::string did_jwk_prefix = "did:jwk:";
const auto id = vc.at("credentialSubject").at("id").get<std::string>();
if (id.find(did_jwk_prefix) != 0) {
throw std::runtime_error("malformed UserInfo VC: ID is not did:jwk");
}
const auto jwk = to_ascii(from_base64url(id.substr(did_jwk_prefix.size())));
auto public_key = Signature::parse_jwk(jwk);
// Extract the salient parts
return std::make_shared<ParsedCredential>(
sig,
kid,
payload.at("iss"),
epoch_time(payload.at("nbf").get<int64_t>()),
epoch_time(payload.at("exp").get<int64_t>()),
UserInfoClaims::from_json(vc.at("credentialSubject").dump()),
std::move(public_key),
to_be_signed,
signature);
}
bool verify(const Signature::PublicKey& issuer_key)
{
return signature_algorithm.verify(to_be_signed, signature, issuer_key);
}
};
///
/// UserInfoClaims
///
UserInfoClaims
UserInfoClaims::from_json(const std::string& cred_subject)
{
const auto& cred_subject_json = nlohmann::json::parse(cred_subject);
std::optional<UserInfoClaimsAddress> address_opt = {};
if (cred_subject_json.contains(address_attr)) {
auto address_json = cred_subject_json.at(address_attr);
address_opt = {
get_optional<std::string>(address_json, address_formatted_attr),
get_optional<std::string>(address_json, address_street_address_attr),
get_optional<std::string>(address_json, address_locality_attr),
get_optional<std::string>(address_json, address_region_attr),
get_optional<std::string>(address_json, address_postal_code_attr),
get_optional<std::string>(address_json, address_country_attr)
};
}
return {
get_optional<std::string>(cred_subject_json, sub_attr),
get_optional<std::string>(cred_subject_json, name_attr),
get_optional<std::string>(cred_subject_json, given_name_attr),
get_optional<std::string>(cred_subject_json, family_name_attr),
get_optional<std::string>(cred_subject_json, middle_name_attr),
get_optional<std::string>(cred_subject_json, nickname_attr),
get_optional<std::string>(cred_subject_json, preferred_username_attr),
get_optional<std::string>(cred_subject_json, profile_attr),
get_optional<std::string>(cred_subject_json, picture_attr),
get_optional<std::string>(cred_subject_json, website_attr),
get_optional<std::string>(cred_subject_json, email_attr),
get_optional<bool>(cred_subject_json, email_verified_attr),
get_optional<std::string>(cred_subject_json, gender_attr),
get_optional<std::string>(cred_subject_json, birthdate_attr),
get_optional<std::string>(cred_subject_json, zoneinfo_attr),
get_optional<std::string>(cred_subject_json, locale_attr),
get_optional<std::string>(cred_subject_json, phone_number_attr),
get_optional<bool>(cred_subject_json, phone_number_verified_attr),
address_opt,
get_optional<uint64_t>(cred_subject_json, updated_at_attr),
};
}
///
/// UserInfoVC
///
UserInfoVC::UserInfoVC(std::string jwt)
: parsed_cred(ParsedCredential::parse(jwt))
, raw(std::move(jwt))
{
}
const Signature&
UserInfoVC::signature_algorithm() const
{
return parsed_cred->signature_algorithm;
}
std::string
UserInfoVC::issuer() const
{
return parsed_cred->issuer;
}
std::optional<std::string>
UserInfoVC::key_id() const
{
return parsed_cred->key_id;
}
bool
UserInfoVC::valid_from(const Signature::PublicKey& issuer_key) const
{
return parsed_cred->verify(issuer_key);
}
const std::string&
UserInfoVC::raw_credential() const
{
return raw;
}
const UserInfoClaims&
UserInfoVC::subject() const
{
return parsed_cred->credential_subject;
}
std::chrono::system_clock::time_point
UserInfoVC::not_before() const
{
return parsed_cred->not_before;
}
std::chrono::system_clock::time_point
UserInfoVC::not_after() const
{
return parsed_cred->not_after;
}
const Signature::PublicJWK&
UserInfoVC::public_key() const
{
return parsed_cred->public_key;
}
bool
operator==(const UserInfoVC& lhs, const UserInfoVC& rhs)
{
return lhs.raw == rhs.raw;
}
} // namespace mlspp::hpke

View File

@@ -0,0 +1,24 @@
set(CURRENT_LIB_NAME mls_vectors)
###
### Library Config
###
file(GLOB_RECURSE LIB_HEADERS CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/include/*.h")
file(GLOB_RECURSE LIB_SOURCES CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp")
add_library(${CURRENT_LIB_NAME} STATIC ${LIB_HEADERS} ${LIB_SOURCES})
add_dependencies(${CURRENT_LIB_NAME} mlspp)
target_link_libraries(${CURRENT_LIB_NAME} mlspp bytes tls_syntax)
target_include_directories(${CURRENT_LIB_NAME}
PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/include>
$<INSTALL_INTERFACE:include/${PROJECT_NAME}>
)
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/bytes/include")
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/hpke/include")
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/mls_vectors/include")
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/tls_syntax/include")

View File

@@ -0,0 +1,577 @@
#pragma once
#include <bytes/bytes.h>
#include <mls/crypto.h>
#include <mls/key_schedule.h>
#include <mls/messages.h>
#include <mls/tree_math.h>
#include <mls/treekem.h>
#include <tls/tls_syntax.h>
#include <vector>
namespace mls_vectors {
struct PseudoRandom
{
struct Generator
{
Generator() = default;
Generator(mlspp::CipherSuite suite_in, const std::string& label);
Generator sub(const std::string& label) const;
bytes secret(const std::string& label) const;
bytes generate(const std::string& label, size_t size) const;
uint16_t uint16(const std::string& label) const;
uint32_t uint32(const std::string& label) const;
uint64_t uint64(const std::string& label) const;
mlspp::SignaturePrivateKey signature_key(
const std::string& label) const;
mlspp::HPKEPrivateKey hpke_key(const std::string& label) const;
size_t output_length() const;
private:
mlspp::CipherSuite suite;
bytes seed;
Generator(mlspp::CipherSuite suite_in, bytes seed_in);
};
PseudoRandom() = default;
PseudoRandom(mlspp::CipherSuite suite, const std::string& label);
Generator prg;
};
struct TreeMathTestVector
{
using OptionalNode = std::optional<mlspp::NodeIndex>;
mlspp::LeafCount n_leaves;
mlspp::NodeCount n_nodes;
mlspp::NodeIndex root;
std::vector<OptionalNode> left;
std::vector<OptionalNode> right;
std::vector<OptionalNode> parent;
std::vector<OptionalNode> sibling;
std::optional<mlspp::NodeIndex> null_if_invalid(
mlspp::NodeIndex input,
mlspp::NodeIndex answer) const;
TreeMathTestVector() = default;
TreeMathTestVector(uint32_t n_leaves);
std::optional<std::string> verify() const;
};
struct CryptoBasicsTestVector : PseudoRandom
{
struct RefHash
{
std::string label;
bytes value;
bytes out;
RefHash() = default;
RefHash(mlspp::CipherSuite suite,
const PseudoRandom::Generator& prg);
std::optional<std::string> verify(mlspp::CipherSuite suite) const;
};
struct ExpandWithLabel
{
bytes secret;
std::string label;
bytes context;
uint16_t length;
bytes out;
ExpandWithLabel() = default;
ExpandWithLabel(mlspp::CipherSuite suite,
const PseudoRandom::Generator& prg);
std::optional<std::string> verify(mlspp::CipherSuite suite) const;
};
struct DeriveSecret
{
bytes secret;
std::string label;
bytes out;
DeriveSecret() = default;
DeriveSecret(mlspp::CipherSuite suite,
const PseudoRandom::Generator& prg);
std::optional<std::string> verify(mlspp::CipherSuite suite) const;
};
struct DeriveTreeSecret
{
bytes secret;
std::string label;
uint32_t generation;
uint16_t length;
bytes out;
DeriveTreeSecret() = default;
DeriveTreeSecret(mlspp::CipherSuite suite,
const PseudoRandom::Generator& prg);
std::optional<std::string> verify(mlspp::CipherSuite suite) const;
};
struct SignWithLabel
{
mlspp::SignaturePrivateKey priv;
mlspp::SignaturePublicKey pub;
bytes content;
std::string label;
bytes signature;
SignWithLabel() = default;
SignWithLabel(mlspp::CipherSuite suite,
const PseudoRandom::Generator& prg);
std::optional<std::string> verify(mlspp::CipherSuite suite) const;
};
struct EncryptWithLabel
{
mlspp::HPKEPrivateKey priv;
mlspp::HPKEPublicKey pub;
std::string label;
bytes context;
bytes plaintext;
bytes kem_output;
bytes ciphertext;
EncryptWithLabel() = default;
EncryptWithLabel(mlspp::CipherSuite suite,
const PseudoRandom::Generator& prg);
std::optional<std::string> verify(mlspp::CipherSuite suite) const;
};
mlspp::CipherSuite cipher_suite;
RefHash ref_hash;
ExpandWithLabel expand_with_label;
DeriveSecret derive_secret;
DeriveTreeSecret derive_tree_secret;
SignWithLabel sign_with_label;
EncryptWithLabel encrypt_with_label;
CryptoBasicsTestVector() = default;
CryptoBasicsTestVector(mlspp::CipherSuite suite);
std::optional<std::string> verify() const;
};
struct SecretTreeTestVector : PseudoRandom
{
struct SenderData
{
bytes sender_data_secret;
bytes ciphertext;
bytes key;
bytes nonce;
SenderData() = default;
SenderData(mlspp::CipherSuite suite,
const PseudoRandom::Generator& prg);
std::optional<std::string> verify(mlspp::CipherSuite suite) const;
};
struct RatchetStep
{
uint32_t generation;
bytes handshake_key;
bytes handshake_nonce;
bytes application_key;
bytes application_nonce;
};
mlspp::CipherSuite cipher_suite;
SenderData sender_data;
bytes encryption_secret;
std::vector<std::vector<RatchetStep>> leaves;
SecretTreeTestVector() = default;
SecretTreeTestVector(mlspp::CipherSuite suite,
uint32_t n_leaves,
const std::vector<uint32_t>& generations);
std::optional<std::string> verify() const;
};
struct KeyScheduleTestVector : PseudoRandom
{
struct Export
{
std::string label;
bytes context;
size_t length;
bytes secret;
};
struct Epoch
{
// Chosen by the generator
bytes tree_hash;
bytes commit_secret;
bytes psk_secret;
bytes confirmed_transcript_hash;
// Computed values
bytes group_context;
bytes joiner_secret;
bytes welcome_secret;
bytes init_secret;
bytes sender_data_secret;
bytes encryption_secret;
bytes exporter_secret;
bytes epoch_authenticator;
bytes external_secret;
bytes confirmation_key;
bytes membership_key;
bytes resumption_psk;
mlspp::HPKEPublicKey external_pub;
Export exporter;
};
mlspp::CipherSuite cipher_suite;
bytes group_id;
bytes initial_init_secret;
std::vector<Epoch> epochs;
KeyScheduleTestVector() = default;
KeyScheduleTestVector(mlspp::CipherSuite suite, uint32_t n_epochs);
std::optional<std::string> verify() const;
};
struct MessageProtectionTestVector : PseudoRandom
{
mlspp::CipherSuite cipher_suite;
bytes group_id;
mlspp::epoch_t epoch;
bytes tree_hash;
bytes confirmed_transcript_hash;
mlspp::SignaturePrivateKey signature_priv;
mlspp::SignaturePublicKey signature_pub;
bytes encryption_secret;
bytes sender_data_secret;
bytes membership_key;
mlspp::Proposal proposal;
mlspp::MLSMessage proposal_pub;
mlspp::MLSMessage proposal_priv;
mlspp::Commit commit;
mlspp::MLSMessage commit_pub;
mlspp::MLSMessage commit_priv;
bytes application;
mlspp::MLSMessage application_priv;
MessageProtectionTestVector() = default;
MessageProtectionTestVector(mlspp::CipherSuite suite);
std::optional<std::string> verify();
private:
mlspp::GroupKeySource group_keys() const;
mlspp::GroupContext group_context() const;
mlspp::MLSMessage protect_pub(
const mlspp::GroupContent::RawContent& raw_content) const;
mlspp::MLSMessage protect_priv(
const mlspp::GroupContent::RawContent& raw_content);
std::optional<mlspp::GroupContent> unprotect(
const mlspp::MLSMessage& message);
};
struct PSKSecretTestVector : PseudoRandom
{
struct PSK
{
bytes psk_id;
bytes psk_nonce;
bytes psk;
};
mlspp::CipherSuite cipher_suite;
std::vector<PSK> psks;
bytes psk_secret;
PSKSecretTestVector() = default;
PSKSecretTestVector(mlspp::CipherSuite suite, size_t n_psks);
std::optional<std::string> verify() const;
};
struct TranscriptTestVector : PseudoRandom
{
mlspp::CipherSuite cipher_suite;
bytes confirmation_key;
bytes interim_transcript_hash_before;
mlspp::AuthenticatedContent authenticated_content;
bytes confirmed_transcript_hash_after;
bytes interim_transcript_hash_after;
TranscriptTestVector() = default;
TranscriptTestVector(mlspp::CipherSuite suite);
std::optional<std::string> verify() const;
};
struct WelcomeTestVector : PseudoRandom
{
mlspp::CipherSuite cipher_suite;
mlspp::HPKEPrivateKey init_priv;
mlspp::SignaturePublicKey signer_pub;
mlspp::MLSMessage key_package;
mlspp::MLSMessage welcome;
WelcomeTestVector() = default;
WelcomeTestVector(mlspp::CipherSuite suite);
std::optional<std::string> verify() const;
};
// XXX(RLB): The |structure| of the example trees below is to avoid compile
// errors from gcc's -Werror=comment when there is a '\' character at the end of
// a line. Inspired by a similar bug in Chromium:
// https://codereview.chromium.org/874663003/patch/1/10001
enum struct TreeStructure
{
// Full trees on N leaves, created by member k adding member k+1
full_tree_2,
full_tree_3,
full_tree_4,
full_tree_5,
full_tree_6,
full_tree_7,
full_tree_8,
full_tree_32,
full_tree_33,
full_tree_34,
// | W |
// | ______|______ |
// | / \ |
// | U Y |
// | __|__ __|__ |
// | / \ / \ |
// | T _ X Z |
// | / \ / \ / \ / \ |
// | A B C _ E F G H |
//
// * Start with full tree on 8 members
// * 0 commits removeing 2 and 3, and adding a new member
internal_blanks_no_skipping,
// | W |
// | ______|______ |
// | / \ |
// | _ Y |
// | __|__ __|__ |
// | / \ / \ |
// | _ _ X Z |
// | / \ / \ / \ / \ |
// | A _ _ _ E F G H |
//
// * Start with full tree on 8 members
// * 0 commitsremoveing 1, 2, and 3
internal_blanks_with_skipping,
// | W[H] |
// | ______|______ |
// | / \ |
// | U Y[H] |
// | __|__ __|__ |
// | / \ / \ |
// | T V X _ |
// | / \ / \ / \ / \ |
// | A B C D E F G H |
//
// * Start with full tree on 7 members
// * 0 commits adding a member in a partial Commit (no path)
unmerged_leaves_no_skipping,
// | W [F] |
// | ______|______ |
// | / \ |
// | U Y [F] |
// | __|__ __|__ |
// | / \ / \ |
// | T _ _ _ |
// | / \ / \ / \ / \ |
// | A B C D E F G _ |
//
// == Fig. 20 / {{parent-hash-tree}}
// * 0 creates group
// * 0 adds 1, ..., 6 in a partial Commit
// * O commits removing 5
// * 4 commits without any proposals
// * 0 commits adding a new member in a partial Commit
unmerged_leaves_with_skipping,
};
extern std::array<TreeStructure, 14> all_tree_structures;
extern std::array<TreeStructure, 11> treekem_test_tree_structures;
struct TreeHashTestVector : PseudoRandom
{
mlspp::CipherSuite cipher_suite;
bytes group_id;
mlspp::TreeKEMPublicKey tree;
std::vector<bytes> tree_hashes;
std::vector<std::vector<mlspp::NodeIndex>> resolutions;
TreeHashTestVector() = default;
TreeHashTestVector(mlspp::CipherSuite suite,
TreeStructure tree_structure);
std::optional<std::string> verify();
};
struct TreeOperationsTestVector : PseudoRandom
{
enum struct Scenario
{
add_right_edge,
add_internal,
update,
remove_right_edge,
remove_internal,
};
static const std::vector<Scenario> all_scenarios;
mlspp::CipherSuite cipher_suite;
mlspp::TreeKEMPublicKey tree_before;
bytes tree_hash_before;
mlspp::Proposal proposal;
mlspp::LeafIndex proposal_sender;
mlspp::TreeKEMPublicKey tree_after;
bytes tree_hash_after;
TreeOperationsTestVector() = default;
TreeOperationsTestVector(mlspp::CipherSuite suite, Scenario scenario);
std::optional<std::string> verify();
};
struct TreeKEMTestVector : PseudoRandom
{
struct PathSecret
{
mlspp::NodeIndex node;
bytes path_secret;
};
struct LeafPrivateInfo
{
mlspp::LeafIndex index;
mlspp::HPKEPrivateKey encryption_priv;
mlspp::SignaturePrivateKey signature_priv;
std::vector<PathSecret> path_secrets;
};
struct UpdatePathInfo
{
mlspp::LeafIndex sender;
mlspp::UpdatePath update_path;
std::vector<std::optional<bytes>> path_secrets;
bytes commit_secret;
bytes tree_hash_after;
};
mlspp::CipherSuite cipher_suite;
bytes group_id;
mlspp::epoch_t epoch;
bytes confirmed_transcript_hash;
mlspp::TreeKEMPublicKey ratchet_tree;
std::vector<LeafPrivateInfo> leaves_private;
std::vector<UpdatePathInfo> update_paths;
TreeKEMTestVector() = default;
TreeKEMTestVector(mlspp::CipherSuite suite,
TreeStructure tree_structure);
std::optional<std::string> verify();
};
struct MessagesTestVector : PseudoRandom
{
bytes mls_welcome;
bytes mls_group_info;
bytes mls_key_package;
bytes ratchet_tree;
bytes group_secrets;
bytes add_proposal;
bytes update_proposal;
bytes remove_proposal;
bytes pre_shared_key_proposal;
bytes re_init_proposal;
bytes external_init_proposal;
bytes group_context_extensions_proposal;
bytes commit;
bytes public_message_proposal;
bytes public_message_commit;
bytes private_message;
MessagesTestVector();
std::optional<std::string> verify() const;
};
struct PassiveClientTestVector : PseudoRandom
{
struct PSK
{
bytes psk_id;
bytes psk;
};
struct Epoch
{
std::vector<mlspp::MLSMessage> proposals;
mlspp::MLSMessage commit;
bytes epoch_authenticator;
};
mlspp::CipherSuite cipher_suite;
mlspp::MLSMessage key_package;
mlspp::SignaturePrivateKey signature_priv;
mlspp::HPKEPrivateKey encryption_priv;
mlspp::HPKEPrivateKey init_priv;
std::vector<PSK> external_psks;
mlspp::MLSMessage welcome;
std::optional<mlspp::TreeKEMPublicKey> ratchet_tree;
bytes initial_epoch_authenticator;
std::vector<Epoch> epochs;
PassiveClientTestVector() = default;
std::optional<std::string> verify();
};
} // namespace mls_vectors

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,21 @@
set(CURRENT_LIB_NAME tls_syntax)
###
### Library Config
###
file(GLOB_RECURSE LIB_HEADERS CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/include/*.h")
file(GLOB_RECURSE LIB_SOURCES CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp")
add_library(${CURRENT_LIB_NAME} STATIC ${LIB_HEADERS} ${LIB_SOURCES})
target_include_directories(${CURRENT_LIB_NAME}
PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/include>
$<INSTALL_INTERFACE:include/${PROJECT_NAME}>
)
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/bytes/include")
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/hpke/include")
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/mls_vectors/include")
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/tls_syntax/include")

View File

@@ -0,0 +1,61 @@
#pragma once
#include <optional>
#include <stdexcept>
#ifdef VARIANT_COMPAT
#include <variant.hpp>
#else
#include <variant>
#endif // VARIANT_COMPAT
namespace mlspp::tls {
namespace var = std;
// In a similar vein, we provide our own safe accessors for std::optional, since
// std::optional::value() is not available on macOS 10.11.
namespace opt {
template<typename T>
T&
get(std::optional<T>& opt)
{
if (!opt) {
throw std::runtime_error("bad_optional_access");
}
return *opt;
}
template<typename T>
const T&
get(const std::optional<T>& opt)
{
if (!opt) {
throw std::runtime_error("bad_optional_access");
}
return *opt;
}
template<typename T>
T&&
get(std::optional<T>&& opt)
{
if (!opt) {
throw std::runtime_error("bad_optional_access");
}
return std::move(*opt);
}
template<typename T>
const T&&
get(const std::optional<T>&& opt)
{
if (!opt) {
throw std::runtime_error("bad_optional_access");
}
return std::move(*opt);
}
} // namespace opt
} // namespace mlspp::tls

View File

@@ -0,0 +1,569 @@
#pragma once
#include <algorithm>
#include <array>
#include <cstdint>
#include <limits>
#include <map>
#include <optional>
#include <stdexcept>
#include <vector>
#include <tls/compat.h>
namespace mlspp::tls {
// For indicating no min or max in vector definitions
const size_t none = std::numeric_limits<size_t>::max();
class WriteError : public std::invalid_argument
{
public:
using parent = std::invalid_argument;
using parent::parent;
};
class ReadError : public std::invalid_argument
{
public:
using parent = std::invalid_argument;
using parent::parent;
};
///
/// Declarations of Streams and Traits
///
class ostream
{
public:
static const size_t none = std::numeric_limits<size_t>::max();
void write_raw(const std::vector<uint8_t>& bytes);
const std::vector<uint8_t>& bytes() const { return _buffer; }
size_t size() const { return _buffer.size(); }
bool empty() const { return _buffer.empty(); }
private:
std::vector<uint8_t> _buffer;
ostream& write_uint(uint64_t value, int length);
friend ostream& operator<<(ostream& out, bool data);
friend ostream& operator<<(ostream& out, uint8_t data);
friend ostream& operator<<(ostream& out, uint16_t data);
friend ostream& operator<<(ostream& out, uint32_t data);
friend ostream& operator<<(ostream& out, uint64_t data);
template<typename T>
friend ostream& operator<<(ostream& out, const std::vector<T>& data);
friend struct varint;
};
class istream
{
public:
istream(const std::vector<uint8_t>& data)
: _buffer(data)
{
// So that we can use the constant-time pop_back
std::reverse(_buffer.begin(), _buffer.end());
}
size_t size() const { return _buffer.size(); }
bool empty() const { return _buffer.empty(); }
std::vector<uint8_t> bytes()
{
auto bytes = _buffer;
std::reverse(bytes.begin(), bytes.end());
return bytes;
}
private:
istream() {}
std::vector<uint8_t> _buffer;
uint8_t next();
template<typename T>
istream& read_uint(T& data, size_t length)
{
uint64_t value = 0;
for (size_t i = 0; i < length; i += 1) {
value = (value << unsigned(8)) + next();
}
data = static_cast<T>(value);
return *this;
}
friend istream& operator>>(istream& in, bool& data);
friend istream& operator>>(istream& in, uint8_t& data);
friend istream& operator>>(istream& in, uint16_t& data);
friend istream& operator>>(istream& in, uint32_t& data);
friend istream& operator>>(istream& in, uint64_t& data);
template<typename T>
friend istream& operator>>(istream& in, std::vector<T>& data);
friend struct varint;
};
// Traits must have static encode and decode methods, of the following form:
//
// static ostream& encode(ostream& str, const T& val);
// static istream& decode(istream& str, T& val);
//
// Trait types will never be constructed; only these static methods are used.
// The value arguments to encode and decode can be as strict or as loose as
// desired.
//
// Ultimately, all interesting encoding should be done through traits.
//
// * vectors
// * variants
// * varints
struct pass
{
template<typename T>
static ostream& encode(ostream& str, const T& val);
template<typename T>
static istream& decode(istream& str, T& val);
};
template<typename Ts>
struct variant
{
template<typename... Tp>
static inline Ts type(const var::variant<Tp...>& data);
template<typename... Tp>
static ostream& encode(ostream& str, const var::variant<Tp...>& data);
template<size_t I = 0, typename Te, typename... Tp>
static inline typename std::enable_if<I == sizeof...(Tp), void>::type
read_variant(istream&, Te, var::variant<Tp...>&);
template<size_t I = 0, typename Te, typename... Tp>
static inline typename std::enable_if <
I<sizeof...(Tp), void>::type read_variant(istream& str,
Te target_type,
var::variant<Tp...>& v);
template<typename... Tp>
static istream& decode(istream& str, var::variant<Tp...>& data);
};
struct varint
{
static ostream& encode(ostream& str, const uint64_t& val);
static istream& decode(istream& str, uint64_t& val);
};
///
/// Writer implementations
///
// Primitive writers defined in .cpp file
// Array writer
template<typename T, size_t N>
ostream&
operator<<(ostream& out, const std::array<T, N>& data)
{
for (const auto& item : data) {
out << item;
}
return out;
}
// Optional writer
template<typename T>
ostream&
operator<<(ostream& out, const std::optional<T>& opt)
{
if (!opt) {
return out << uint8_t(0);
}
return out << uint8_t(1) << opt::get(opt);
}
// Enum writer
template<typename T, std::enable_if_t<std::is_enum<T>::value, int> = 0>
ostream&
operator<<(ostream& str, const T& val)
{
auto u = static_cast<std::underlying_type_t<T>>(val);
return str << u;
}
// Vector writer
template<typename T>
ostream&
operator<<(ostream& str, const std::vector<T>& vec)
{
// Pre-encode contents
ostream temp;
for (const auto& item : vec) {
temp << item;
}
// Write the encoded length, then the pre-encoded data
varint::encode(str, temp._buffer.size());
str.write_raw(temp.bytes());
return str;
}
///
/// Reader implementations
///
// Primitive type readers defined in .cpp file
// Array reader
template<typename T, size_t N>
istream&
operator>>(istream& in, std::array<T, N>& data)
{
for (auto& item : data) {
in >> item;
}
return in;
}
// Optional reader
template<typename T>
istream&
operator>>(istream& in, std::optional<T>& opt)
{
uint8_t present = 0;
in >> present;
switch (present) {
case 0:
opt.reset();
return in;
case 1:
opt.emplace();
return in >> opt::get(opt);
default:
throw std::invalid_argument("Malformed optional");
}
}
// Enum reader
// XXX(rlb): It would be nice if this could enforce that the values are valid,
// but C++ doesn't seem to have that ability. When used as a tag for variants,
// the variant reader will enforce, at least.
template<typename T, std::enable_if_t<std::is_enum<T>::value, int> = 0>
istream&
operator>>(istream& str, T& val)
{
std::underlying_type_t<T> u;
str >> u;
val = static_cast<T>(u);
return str;
}
// Vector reader
template<typename T>
istream&
operator>>(istream& str, std::vector<T>& vec)
{
// Read the encoded data size
auto size = uint64_t(0);
varint::decode(str, size);
if (size > str._buffer.size()) {
throw ReadError("Vector is longer than remaining data");
}
// Read the elements of the vector
// NB: Remember that we store the vector in reverse order
// NB: This requires that T be default-constructible
istream r;
r._buffer =
std::vector<uint8_t>{ str._buffer.end() - size, str._buffer.end() };
vec.clear();
while (r._buffer.size() > 0) {
vec.emplace_back();
r >> vec.back();
}
// Truncate the primary buffer
str._buffer.erase(str._buffer.end() - size, str._buffer.end());
return str;
}
// Abbreviations
template<typename T>
std::vector<uint8_t>
marshal(const T& value)
{
ostream w;
w << value;
return w.bytes();
}
template<typename T>
void
unmarshal(const std::vector<uint8_t>& data, T& value)
{
istream r(data);
r >> value;
}
template<typename T, typename... Tp>
T
get(const std::vector<uint8_t>& data, Tp... args)
{
T value(args...);
unmarshal(data, value);
return value;
}
// Use this macro to define struct serialization with minimal boilerplate
#define TLS_SERIALIZABLE(...) \
static const bool _tls_serializable = true; \
auto _tls_fields_r() \
{ \
return std::forward_as_tuple(__VA_ARGS__); \
} \
auto _tls_fields_w() const \
{ \
return std::forward_as_tuple(__VA_ARGS__); \
}
// If your struct contains nontrivial members (e.g., vectors), use this to
// define traits for them.
#define TLS_TRAITS(...) \
static const bool _tls_has_traits = true; \
using _tls_traits = std::tuple<__VA_ARGS__>;
template<typename T>
struct is_serializable
{
template<typename U>
static std::true_type test(decltype(U::_tls_serializable));
template<typename U>
static std::false_type test(...);
static const bool value = decltype(test<T>(true))::value;
};
template<typename T>
struct has_traits
{
template<typename U>
static std::true_type test(decltype(U::_tls_has_traits));
template<typename U>
static std::false_type test(...);
static const bool value = decltype(test<T>(true))::value;
};
///
/// Trait implementations
///
// Pass-through (normal encoding/decoding)
template<typename T>
ostream&
pass::encode(ostream& str, const T& val)
{
return str << val;
}
template<typename T>
istream&
pass::decode(istream& str, T& val)
{
return str >> val;
}
// Variant encoding
template<typename Ts, typename Tv>
constexpr Ts
variant_map();
#define TLS_VARIANT_MAP(EnumType, MappedType, enum_value) \
template<> \
constexpr EnumType variant_map<EnumType, MappedType>() \
{ \
return EnumType::enum_value; \
}
template<typename Ts>
template<typename... Tp>
inline Ts
variant<Ts>::type(const var::variant<Tp...>& data)
{
const auto get_type = [](const auto& v) {
return variant_map<Ts, std::decay_t<decltype(v)>>();
};
return var::visit(get_type, data);
}
template<typename Ts>
template<typename... Tp>
ostream&
variant<Ts>::encode(ostream& str, const var::variant<Tp...>& data)
{
const auto write_variant = [&str](auto&& v) {
using Tv = std::decay_t<decltype(v)>;
str << variant_map<Ts, Tv>() << v;
};
var::visit(write_variant, data);
return str;
}
template<typename Ts>
template<size_t I, typename Te, typename... Tp>
inline typename std::enable_if<I == sizeof...(Tp), void>::type
variant<Ts>::read_variant(istream&, Te, var::variant<Tp...>&)
{
throw ReadError("Invalid variant type label");
}
template<typename Ts>
template<size_t I, typename Te, typename... Tp>
inline
typename std::enable_if < I<sizeof...(Tp), void>::type
variant<Ts>::read_variant(istream& str,
Te target_type,
var::variant<Tp...>& v)
{
using Tc = var::variant_alternative_t<I, var::variant<Tp...>>;
if (variant_map<Ts, Tc>() == target_type) {
str >> v.template emplace<I>();
return;
}
read_variant<I + 1>(str, target_type, v);
}
template<typename Ts>
template<typename... Tp>
istream&
variant<Ts>::decode(istream& str, var::variant<Tp...>& data)
{
Ts target_type;
str >> target_type;
read_variant(str, target_type, data);
return str;
}
// Struct writer without traits (enabled by macro)
template<size_t I = 0, typename... Tp>
inline typename std::enable_if<I == sizeof...(Tp), void>::type
write_tuple(ostream&, const std::tuple<Tp...>&)
{
}
template<size_t I = 0, typename... Tp>
inline typename std::enable_if <
I<sizeof...(Tp), void>::type
write_tuple(ostream& str, const std::tuple<Tp...>& t)
{
str << std::get<I>(t);
write_tuple<I + 1, Tp...>(str, t);
}
template<typename T>
inline
typename std::enable_if<is_serializable<T>::value && !has_traits<T>::value,
ostream&>::type
operator<<(ostream& str, const T& obj)
{
write_tuple(str, obj._tls_fields_w());
return str;
}
// Struct writer with traits (enabled by macro)
template<typename Tr, size_t I = 0, typename... Tp>
inline typename std::enable_if<I == sizeof...(Tp), void>::type
write_tuple_traits(ostream&, const std::tuple<Tp...>&)
{
}
template<typename Tr, size_t I = 0, typename... Tp>
inline typename std::enable_if <
I<sizeof...(Tp), void>::type
write_tuple_traits(ostream& str, const std::tuple<Tp...>& t)
{
std::tuple_element_t<I, Tr>::encode(str, std::get<I>(t));
write_tuple_traits<Tr, I + 1, Tp...>(str, t);
}
template<typename T>
inline
typename std::enable_if<is_serializable<T>::value && has_traits<T>::value,
ostream&>::type
operator<<(ostream& str, const T& obj)
{
write_tuple_traits<typename T::_tls_traits>(str, obj._tls_fields_w());
return str;
}
// Struct reader without traits (enabled by macro)
template<size_t I = 0, typename... Tp>
inline typename std::enable_if<I == sizeof...(Tp), void>::type
read_tuple(istream&, const std::tuple<Tp...>&)
{
}
template<size_t I = 0, typename... Tp>
inline
typename std::enable_if < I<sizeof...(Tp), void>::type
read_tuple(istream& str, const std::tuple<Tp...>& t)
{
str >> std::get<I>(t);
read_tuple<I + 1, Tp...>(str, t);
}
template<typename T>
inline
typename std::enable_if<is_serializable<T>::value && !has_traits<T>::value,
istream&>::type
operator>>(istream& str, T& obj)
{
read_tuple(str, obj._tls_fields_r());
return str;
}
// Struct reader with traits (enabled by macro)
template<typename Tr, size_t I = 0, typename... Tp>
inline typename std::enable_if<I == sizeof...(Tp), void>::type
read_tuple_traits(istream&, const std::tuple<Tp...>&)
{
}
template<typename Tr, size_t I = 0, typename... Tp>
inline typename std::enable_if <
I<sizeof...(Tp), void>::type
read_tuple_traits(istream& str, const std::tuple<Tp...>& t)
{
std::tuple_element_t<I, Tr>::decode(str, std::get<I>(t));
read_tuple_traits<Tr, I + 1, Tp...>(str, t);
}
template<typename T>
inline
typename std::enable_if<is_serializable<T>::value && has_traits<T>::value,
istream&>::type
operator>>(istream& str, T& obj)
{
read_tuple_traits<typename T::_tls_traits>(str, obj._tls_fields_r());
return str;
}
} // namespace mlspp::tls

View File

@@ -0,0 +1,178 @@
#include <tls/tls_syntax.h>
// NOLINTNEXTLINE(llvmlibc-implementation-in-namespace)
namespace mlspp::tls {
void
ostream::write_raw(const std::vector<uint8_t>& bytes)
{
// Not sure what the default argument is here
_buffer.insert(_buffer.end(), bytes.begin(), bytes.end());
}
// Primitive type writers
ostream&
ostream::write_uint(uint64_t value, int length)
{
for (int i = length - 1; i >= 0; --i) {
_buffer.push_back(static_cast<uint8_t>(value >> unsigned(8 * i)));
}
return *this;
}
ostream&
operator<<(ostream& out, bool data)
{
if (data) {
return out << uint8_t(1);
}
return out << uint8_t(0);
}
ostream&
operator<<(ostream& out, uint8_t data) // NOLINT(llvmlibc-callee-namespace)
{
return out.write_uint(data, 1);
}
ostream&
operator<<(ostream& out, uint16_t data)
{
return out.write_uint(data, 2);
}
ostream&
operator<<(ostream& out, uint32_t data)
{
return out.write_uint(data, 4);
}
ostream&
operator<<(ostream& out, uint64_t data)
{
return out.write_uint(data, 8);
}
// Because pop_back() on an empty vector is undefined
uint8_t
istream::next()
{
if (_buffer.empty()) {
throw ReadError("Attempt to read from empty buffer");
}
const uint8_t value = _buffer.back();
_buffer.pop_back();
return value;
}
// Primitive type readers
istream&
operator>>(istream& in, bool& data)
{
uint8_t val = 0;
in >> val;
// Linter thinks uint8_t is signed (?)
// NOLINTNEXTLINE(hicpp-signed-bitwise)
if ((val & 0xFE) != 0) {
throw ReadError("Malformed boolean");
}
data = (val == 1);
return in;
}
istream&
operator>>(istream& in, uint8_t& data) // NOLINT(llvmlibc-callee-namespace)
{
return in.read_uint(data, 1);
}
istream&
operator>>(istream& in, uint16_t& data)
{
return in.read_uint(data, 2);
}
istream&
operator>>(istream& in, uint32_t& data)
{
return in.read_uint(data, 4);
}
istream&
operator>>(istream& in, uint64_t& data)
{
return in.read_uint(data, 8);
}
// Varint encoding
static constexpr size_t VARINT_HEADER_OFFSET = 6;
static constexpr uint64_t VARINT_1_HEADER = 0x00; // 0 << V1_OFFSET
static constexpr uint64_t VARINT_2_HEADER = 0x4000; // 1 << V2_OFFSET
static constexpr uint64_t VARINT_4_HEADER = 0x80000000; // 2 << V4_OFFSET
static constexpr uint64_t VARINT_1_MAX = 0x3f;
static constexpr uint64_t VARINT_2_MAX = 0x3fff;
static constexpr uint64_t VARINT_4_MAX = 0x3fffffff;
ostream&
varint::encode(ostream& str, const uint64_t& val)
{
if (val <= VARINT_1_MAX) {
return str.write_uint(VARINT_1_HEADER | val, 1);
}
if (val <= VARINT_2_MAX) {
return str.write_uint(VARINT_2_HEADER | val, 2);
}
if (val <= VARINT_4_MAX) {
return str.write_uint(VARINT_4_HEADER | val, 4);
}
throw WriteError("Varint value exceeds maximum size");
}
istream&
varint::decode(istream& str, uint64_t& val)
{
auto log_size = size_t(str._buffer.back() >> VARINT_HEADER_OFFSET);
if (log_size > 2) {
throw ReadError("Malformed varint header");
}
auto read = uint64_t(0);
auto read_bytes = size_t(size_t(1) << log_size);
str.read_uint(read, read_bytes);
switch (log_size) {
case 0:
read ^= VARINT_1_HEADER;
break;
case 1:
read ^= VARINT_2_HEADER;
if (read <= VARINT_1_MAX) {
throw ReadError("Non-minimal varint");
}
break;
case 2:
read ^= VARINT_4_HEADER;
if (read <= VARINT_2_MAX) {
throw ReadError("Non-minimal varint");
}
break;
default:
throw ReadError("Malformed varint header");
}
val = read;
return str;
}
} // namespace mlspp::tls

13
DPP/mlspp/src/common.cpp Executable file
View File

@@ -0,0 +1,13 @@
#include "mls/common.h"
namespace mlspp {
uint64_t
seconds_since_epoch()
{
// TODO(RLB) This should use std::chrono, but that seems not to be available
// on some platforms.
return std::time(nullptr);
}
} // namespace mlspp

443
DPP/mlspp/src/core_types.cpp Executable file
View File

@@ -0,0 +1,443 @@
#include "mls/core_types.h"
#include "mls/messages.h"
#include "grease.h"
#include <set>
namespace mlspp {
///
/// Extensions
///
const Extension::Type RequiredCapabilitiesExtension::type =
ExtensionType::required_capabilities;
const Extension::Type ApplicationIDExtension::type =
ExtensionType::application_id;
const std::array<uint16_t, 5> default_extensions = {
ExtensionType::application_id, ExtensionType::ratchet_tree,
ExtensionType::required_capabilities, ExtensionType::external_pub,
ExtensionType::external_senders,
};
const std::array<uint16_t, 8> default_proposals = {
ProposalType::add,
ProposalType::update,
ProposalType::remove,
ProposalType::psk,
ProposalType::reinit,
ProposalType::external_init,
ProposalType::group_context_extensions,
};
const std::array<ProtocolVersion, 1> all_supported_versions = {
ProtocolVersion::mls10
};
const std::array<CipherSuite::ID, 6> all_supported_ciphersuites = {
CipherSuite::ID::X25519_AES128GCM_SHA256_Ed25519,
CipherSuite::ID::P256_AES128GCM_SHA256_P256,
CipherSuite::ID::X25519_CHACHA20POLY1305_SHA256_Ed25519,
CipherSuite::ID::X448_AES256GCM_SHA512_Ed448,
CipherSuite::ID::P521_AES256GCM_SHA512_P521,
CipherSuite::ID::X448_CHACHA20POLY1305_SHA512_Ed448,
};
const std::array<CredentialType, 4> all_supported_credentials = {
CredentialType::basic,
CredentialType::x509,
CredentialType::userinfo_vc_draft_00,
CredentialType::multi_draft_00
};
Capabilities
Capabilities::create_default()
{
return {
{ all_supported_versions.begin(), all_supported_versions.end() },
{ all_supported_ciphersuites.begin(), all_supported_ciphersuites.end() },
{ /* No non-default extensions */ },
{ /* No non-default proposals */ },
{ all_supported_credentials.begin(), all_supported_credentials.end() },
};
}
bool
Capabilities::extensions_supported(
const std::vector<Extension::Type>& required) const
{
return stdx::all_of(required, [&](Extension::Type type) {
if (stdx::contains(default_extensions, type)) {
return true;
}
return stdx::contains(extensions, type);
});
}
bool
Capabilities::proposals_supported(
const std::vector<Proposal::Type>& required) const
{
return stdx::all_of(required, [&](Proposal::Type type) {
if (stdx::contains(default_proposals, type)) {
return true;
}
return stdx::contains(proposals, type);
});
}
bool
Capabilities::credential_supported(const Credential& credential) const
{
return stdx::contains(credentials, credential.type());
}
Lifetime
Lifetime::create_default()
{
return Lifetime{ 0x0000000000000000, 0xffffffffffffffff };
}
void
ExtensionList::add(uint16_t type, bytes data)
{
auto curr = stdx::find_if(
extensions, [&](const Extension& ext) -> bool { return ext.type == type; });
if (curr != extensions.end()) {
curr->data = std::move(data);
return;
}
extensions.push_back({ type, std::move(data) });
}
bool
ExtensionList::has(uint16_t type) const
{
return stdx::any_of(extensions,
[&](const Extension& ext) { return ext.type == type; });
}
///
/// LeafNode
///
LeafNode::LeafNode(CipherSuite cipher_suite,
HPKEPublicKey encryption_key_in,
SignaturePublicKey signature_key_in,
Credential credential_in,
Capabilities capabilities_in,
Lifetime lifetime_in,
ExtensionList extensions_in,
const SignaturePrivateKey& sig_priv)
: encryption_key(std::move(encryption_key_in))
, signature_key(std::move(signature_key_in))
, credential(std::move(credential_in))
, capabilities(std::move(capabilities_in))
, content(lifetime_in)
, extensions(std::move(extensions_in))
{
grease(extensions);
grease(capabilities, extensions);
sign(cipher_suite, sig_priv, std::nullopt);
}
void
LeafNode::set_capabilities(Capabilities capabilities_in)
{
capabilities = std::move(capabilities_in);
grease(capabilities, extensions);
}
LeafNode
LeafNode::for_update(CipherSuite cipher_suite,
const bytes& group_id,
LeafIndex leaf_index,
HPKEPublicKey encryption_key_in,
const LeafNodeOptions& opts,
const SignaturePrivateKey& sig_priv) const
{
auto clone = clone_with_options(std::move(encryption_key_in), opts);
clone.content = Empty{};
clone.sign(cipher_suite, sig_priv, { { group_id, leaf_index } });
return clone;
}
LeafNode
LeafNode::for_commit(CipherSuite cipher_suite,
const bytes& group_id,
LeafIndex leaf_index,
HPKEPublicKey encryption_key_in,
const bytes& parent_hash,
const LeafNodeOptions& opts,
const SignaturePrivateKey& sig_priv) const
{
auto clone = clone_with_options(std::move(encryption_key_in), opts);
clone.content = ParentHash{ parent_hash };
clone.sign(cipher_suite, sig_priv, { { group_id, leaf_index } });
return clone;
}
LeafNodeSource
LeafNode::source() const
{
return tls::variant<LeafNodeSource>::type(content);
}
void
LeafNode::sign(CipherSuite cipher_suite,
const SignaturePrivateKey& sig_priv,
const std::optional<MemberBinding>& binding)
{
const auto tbs = to_be_signed(binding);
if (sig_priv.public_key != signature_key) {
throw InvalidParameterError("Signature key mismatch");
}
if (!credential.valid_for(signature_key)) {
throw InvalidParameterError("Credential not valid for signature key");
}
signature = sig_priv.sign(cipher_suite, sign_label::leaf_node, tbs);
}
bool
LeafNode::verify(CipherSuite cipher_suite,
const std::optional<MemberBinding>& binding) const
{
const auto tbs = to_be_signed(binding);
if (CredentialType::x509 == credential.type()) {
const auto& cred = credential.get<X509Credential>();
if (cred.signature_scheme() !=
tls_signature_scheme(cipher_suite.sig().id)) {
throw std::runtime_error("Signature algorithm invalid");
}
}
return signature_key.verify(
cipher_suite, sign_label::leaf_node, tbs, signature);
}
bool
LeafNode::verify_expiry(uint64_t now) const
{
const auto valid = overloaded{
[now](const Lifetime& lt) {
return lt.not_before <= now && now <= lt.not_after;
},
[](const auto& /* other */) { return false; },
};
return var::visit(valid, content);
}
bool
LeafNode::verify_extension_support(const ExtensionList& ext_list) const
{
// Verify that extensions in the list are supported
auto ext_types = stdx::transform<Extension::Type>(
ext_list.extensions, [](const auto& ext) { return ext.type; });
if (!capabilities.extensions_supported(ext_types)) {
return false;
}
// If there's a RequiredCapabilities extension, verify support
const auto maybe_req_capas = ext_list.find<RequiredCapabilitiesExtension>();
if (!maybe_req_capas) {
return true;
}
const auto& req_capas = opt::get(maybe_req_capas);
return capabilities.extensions_supported(req_capas.extensions) &&
capabilities.proposals_supported(req_capas.proposals);
}
LeafNode
LeafNode::clone_with_options(HPKEPublicKey encryption_key_in,
const LeafNodeOptions& opts) const
{
auto clone = *this;
clone.encryption_key = std::move(encryption_key_in);
if (opts.credential) {
clone.credential = opt::get(opts.credential);
}
if (opts.capabilities) {
clone.capabilities = opt::get(opts.capabilities);
}
if (opts.extensions) {
clone.extensions = opt::get(opts.extensions);
}
return clone;
}
// struct {
// HPKEPublicKey encryption_key;
// SignaturePublicKey signature_key;
// Credential credential;
// Capabilities capabilities;
//
// LeafNodeSource leaf_node_source;
// select (leaf_node_source) {
// case key_package:
// Lifetime lifetime;
//
// case update:
// struct{};
//
// case commit:
// opaque parent_hash<V>;
// }
//
// Extension extensions<V>;
//
// select (leaf_node_source) {
// case key_package:
// struct{};
//
// case update:
// opaque group_id<V>;
//
// case commit:
// opaque group_id<V>;
// }
// } LeafNodeTBS;
struct LeafNodeTBS
{
const HPKEPublicKey& encryption_key;
const SignaturePublicKey& signature_key;
const Credential& credential;
const Capabilities& capabilities;
const var::variant<Lifetime, Empty, ParentHash>& content;
const ExtensionList& extensions;
TLS_SERIALIZABLE(encryption_key,
signature_key,
credential,
capabilities,
content,
extensions)
TLS_TRAITS(tls::pass,
tls::pass,
tls::pass,
tls::pass,
tls::variant<LeafNodeSource>,
tls::pass)
};
bytes
LeafNode::to_be_signed(const std::optional<MemberBinding>& binding) const
{
tls::ostream w;
w << LeafNodeTBS{
encryption_key, signature_key, credential,
capabilities, content, extensions,
};
switch (source()) {
case LeafNodeSource::key_package:
break;
case LeafNodeSource::update:
case LeafNodeSource::commit:
w << opt::get(binding);
}
return w.bytes();
}
///
/// NodeType, ParentNode, and KeyPackage
///
bytes
ParentNode::hash(CipherSuite suite) const
{
return suite.digest().hash(tls::marshal(this));
}
KeyPackage::KeyPackage()
: version(ProtocolVersion::mls10)
, cipher_suite(CipherSuite::ID::unknown)
{
}
KeyPackage::KeyPackage(CipherSuite suite_in,
HPKEPublicKey init_key_in,
LeafNode leaf_node_in,
ExtensionList extensions_in,
const SignaturePrivateKey& sig_priv_in)
: version(ProtocolVersion::mls10)
, cipher_suite(suite_in)
, init_key(std::move(init_key_in))
, leaf_node(std::move(leaf_node_in))
, extensions(std::move(extensions_in))
{
grease(extensions);
sign(sig_priv_in);
}
KeyPackageRef
KeyPackage::ref() const
{
return cipher_suite.ref(*this);
}
void
KeyPackage::sign(const SignaturePrivateKey& sig_priv)
{
auto tbs = to_be_signed();
signature = sig_priv.sign(cipher_suite, sign_label::key_package, tbs);
}
bool
KeyPackage::verify() const
{
// Verify the inner leaf node
if (!leaf_node.verify(cipher_suite, std::nullopt)) {
return false;
}
// Check that the inner leaf node is intended for use in a KeyPackage
if (leaf_node.source() != LeafNodeSource::key_package) {
return false;
}
// Verify the KeyPackage
const auto tbs = to_be_signed();
if (CredentialType::x509 == leaf_node.credential.type()) {
const auto& cred = leaf_node.credential.get<X509Credential>();
if (cred.signature_scheme() !=
tls_signature_scheme(cipher_suite.sig().id)) {
throw std::runtime_error("Signature algorithm invalid");
}
}
return leaf_node.signature_key.verify(
cipher_suite, sign_label::key_package, tbs, signature);
}
bytes
KeyPackage::to_be_signed() const
{
tls::ostream out;
out << version << cipher_suite << init_key << leaf_node << extensions;
return out.bytes();
}
} // namespace mlspp

298
DPP/mlspp/src/credential.cpp Executable file
View File

@@ -0,0 +1,298 @@
#include <hpke/certificate.h>
#include <hpke/userinfo_vc.h>
#include <mls/credential.h>
#include <tls/tls_syntax.h>
namespace mlspp {
///
/// X509Credential
///
using mlspp::hpke::Certificate; // NOLINT(misc-unused-using-decls)
using mlspp::hpke::Signature; // NOLINT(misc-unused-using-decls)
using mlspp::hpke::UserInfoVC; // NOLINT(misc-unused-using-decls)
static const Signature&
find_signature(Signature::ID id)
{
switch (id) {
case Signature::ID::P256_SHA256:
return Signature::get<Signature::ID::P256_SHA256>();
case Signature::ID::P384_SHA384:
return Signature::get<Signature::ID::P384_SHA384>();
case Signature::ID::P521_SHA512:
return Signature::get<Signature::ID::P521_SHA512>();
case Signature::ID::Ed25519:
return Signature::get<Signature::ID::Ed25519>();
#if !defined(WITH_BORINGSSL)
case Signature::ID::Ed448:
return Signature::get<Signature::ID::Ed448>();
#endif
case Signature::ID::RSA_SHA256:
return Signature::get<Signature::ID::RSA_SHA256>();
default:
throw InvalidParameterError("Unsupported algorithm");
}
}
static std::vector<X509Credential::CertData>
bytes_to_x509_credential_data(const std::vector<bytes>& data_in)
{
return stdx::transform<X509Credential::CertData>(
data_in, [](const bytes& der) { return X509Credential::CertData{ der }; });
}
X509Credential::X509Credential(const std::vector<bytes>& der_chain_in)
: der_chain(bytes_to_x509_credential_data(der_chain_in))
{
if (der_chain.empty()) {
throw std::invalid_argument("empty certificate chain");
}
// Parse the chain
auto parsed = std::vector<Certificate>();
for (const auto& cert : der_chain) {
parsed.emplace_back(cert.data);
}
// first element represents leaf cert
const auto& sig = find_signature(parsed[0].public_key_algorithm());
const auto pub_data = sig.serialize(*parsed[0].public_key);
_signature_scheme = tls_signature_scheme(parsed[0].public_key_algorithm());
_public_key = SignaturePublicKey{ pub_data };
// verify chain for valid signatures
for (size_t i = 0; i < der_chain.size() - 1; i++) {
if (!parsed[i].valid_from(parsed[i + 1])) {
throw std::runtime_error("Certificate Chain validation failure");
}
}
}
SignatureScheme
X509Credential::signature_scheme() const
{
return _signature_scheme;
}
SignaturePublicKey
X509Credential::public_key() const
{
return _public_key;
}
bool
X509Credential::valid_for(const SignaturePublicKey& pub) const
{
return pub == public_key();
}
tls::ostream&
operator<<(tls::ostream& str, const X509Credential& obj)
{
return str << obj.der_chain;
}
tls::istream&
operator>>(tls::istream& str, X509Credential& obj)
{
auto der_chain = std::vector<X509Credential::CertData>{};
str >> der_chain;
auto der_in = stdx::transform<bytes>(
der_chain, [](const auto& cert_data) { return cert_data.data; });
obj = X509Credential(der_in);
return str;
}
bool
operator==(const X509Credential& lhs, const X509Credential& rhs)
{
return lhs.der_chain == rhs.der_chain;
}
///
/// UserInfoVCCredential
///
UserInfoVCCredential::UserInfoVCCredential(std::string userinfo_vc_jwt_in)
: userinfo_vc_jwt(std::move(userinfo_vc_jwt_in))
, _vc(std::make_shared<hpke::UserInfoVC>(userinfo_vc_jwt))
{
}
bool
// NOLINTNEXTLINE(readability-convert-member-functions-to-static)
UserInfoVCCredential::valid_for(const SignaturePublicKey& pub) const
{
const auto& vc_pub = _vc->public_key();
return pub.data == vc_pub.sig.serialize(*vc_pub.key);
}
bool
UserInfoVCCredential::valid_from(const PublicJWK& pub) const
{
const auto& sig = _vc->signature_algorithm();
if (pub.signature_scheme != tls_signature_scheme(sig.id)) {
return false;
}
const auto sig_pub = sig.deserialize(pub.public_key.data);
return _vc->valid_from(*sig_pub);
}
tls::ostream
operator<<(tls::ostream& str, const UserInfoVCCredential& obj)
{
return str << from_ascii(obj.userinfo_vc_jwt);
}
tls::istream
operator>>(tls::istream& str, UserInfoVCCredential& obj)
{
auto jwt = bytes{};
str >> jwt;
obj = UserInfoVCCredential(to_ascii(jwt));
return str;
}
bool
operator==(const UserInfoVCCredential& lhs, const UserInfoVCCredential& rhs)
{
return lhs.userinfo_vc_jwt == rhs.userinfo_vc_jwt;
}
bool
operator!=(const UserInfoVCCredential& lhs, const UserInfoVCCredential& rhs)
{
return !(lhs == rhs);
}
///
/// CredentialBinding and MultiCredential
///
struct CredentialBindingTBS
{
const CipherSuite& cipher_suite;
const Credential& credential;
const SignaturePublicKey& credential_key;
const SignaturePublicKey& signature_key;
TLS_SERIALIZABLE(cipher_suite, credential, credential_key, signature_key)
};
CredentialBinding::CredentialBinding(CipherSuite cipher_suite_in,
Credential credential_in,
const SignaturePrivateKey& credential_priv,
const SignaturePublicKey& signature_key)
: cipher_suite(cipher_suite_in)
, credential(std::move(credential_in))
, credential_key(credential_priv.public_key)
{
if (credential.type() == CredentialType::multi_draft_00) {
throw InvalidParameterError("Multi-credentials cannot be nested");
}
if (!credential.valid_for(credential_key)) {
throw InvalidParameterError("Credential key does not match credential");
}
signature = credential_priv.sign(
cipher_suite, sign_label::multi_credential, to_be_signed(signature_key));
}
bytes
CredentialBinding::to_be_signed(const SignaturePublicKey& signature_key) const
{
return tls::marshal(CredentialBindingTBS{
cipher_suite, credential, credential_key, signature_key });
}
bool
CredentialBinding::valid_for(const SignaturePublicKey& signature_key) const
{
auto valid_self = credential.valid_for(credential_key);
auto valid_other = credential_key.verify(cipher_suite,
sign_label::multi_credential,
to_be_signed(signature_key),
signature);
return valid_self && valid_other;
}
MultiCredential::MultiCredential(
const std::vector<CredentialBindingInput>& binding_inputs,
const SignaturePublicKey& signature_key)
{
bindings =
stdx::transform<CredentialBinding>(binding_inputs, [&](auto&& input) {
return CredentialBinding(input.cipher_suite,
input.credential,
input.credential_priv,
signature_key);
});
}
bool
MultiCredential::valid_for(const SignaturePublicKey& pub) const
{
return stdx::all_of(
bindings, [&](const auto& binding) { return binding.valid_for(pub); });
}
///
/// Credential
///
CredentialType
Credential::type() const
{
return tls::variant<CredentialType>::type(_cred);
}
Credential
Credential::basic(const bytes& identity)
{
return { BasicCredential{ identity } };
}
Credential
Credential::x509(const std::vector<bytes>& der_chain)
{
return { X509Credential{ der_chain } };
}
Credential
Credential::multi(const std::vector<CredentialBindingInput>& binding_inputs,
const SignaturePublicKey& signature_key)
{
return { MultiCredential{ binding_inputs, signature_key } };
}
Credential
Credential::userinfo_vc(const std::string& userinfo_vc_jwt)
{
return { UserInfoVCCredential{ userinfo_vc_jwt } };
}
bool
Credential::valid_for(const SignaturePublicKey& pub) const
{
const auto pub_key_match = overloaded{
[&](const X509Credential& x509) { return x509.valid_for(pub); },
[](const BasicCredential& /* basic */) { return true; },
[&](const UserInfoVCCredential& vc) { return vc.valid_for(pub); },
[&](const MultiCredential& multi) { return multi.valid_for(pub); },
};
return var::visit(pub_key_match, _cred);
}
Credential::Credential(SpecificCredential specific)
: _cred(std::move(specific))
{
}
} // namespace mlspp

498
DPP/mlspp/src/crypto.cpp Executable file
View File

@@ -0,0 +1,498 @@
#include <mls/core_types.h>
#include <mls/crypto.h>
#include <mls/messages.h>
#include <string>
using mlspp::hpke::AEAD; // NOLINT(misc-unused-using-decls)
using mlspp::hpke::Digest; // NOLINT(misc-unused-using-decls)
using mlspp::hpke::HPKE; // NOLINT(misc-unused-using-decls)
using mlspp::hpke::KDF; // NOLINT(misc-unused-using-decls)
using mlspp::hpke::KEM; // NOLINT(misc-unused-using-decls)
using mlspp::hpke::Signature; // NOLINT(misc-unused-using-decls)
namespace mlspp {
SignatureScheme
tls_signature_scheme(Signature::ID id)
{
switch (id) {
case Signature::ID::P256_SHA256:
return SignatureScheme::ecdsa_secp256r1_sha256;
case Signature::ID::P384_SHA384:
return SignatureScheme::ecdsa_secp384r1_sha384;
case Signature::ID::P521_SHA512:
return SignatureScheme::ecdsa_secp521r1_sha512;
case Signature::ID::Ed25519:
return SignatureScheme::ed25519;
#if !defined(WITH_BORINGSSL)
case Signature::ID::Ed448:
return SignatureScheme::ed448;
#endif
case Signature::ID::RSA_SHA256:
return SignatureScheme::rsa_pkcs1_sha256;
default:
throw InvalidParameterError("Unsupported algorithm");
}
}
///
/// CipherSuites and details
///
CipherSuite::CipherSuite()
: id(ID::unknown)
{
}
CipherSuite::CipherSuite(ID id_in)
: id(id_in)
{
}
SignatureScheme
CipherSuite::signature_scheme() const
{
switch (id) {
case ID::X25519_AES128GCM_SHA256_Ed25519:
case ID::X25519_CHACHA20POLY1305_SHA256_Ed25519:
return SignatureScheme::ed25519;
case ID::P256_AES128GCM_SHA256_P256:
return SignatureScheme::ecdsa_secp256r1_sha256;
case ID::X448_AES256GCM_SHA512_Ed448:
case ID::X448_CHACHA20POLY1305_SHA512_Ed448:
return SignatureScheme::ed448;
case ID::P521_AES256GCM_SHA512_P521:
return SignatureScheme::ecdsa_secp521r1_sha512;
case ID::P384_AES256GCM_SHA384_P384:
return SignatureScheme::ecdsa_secp384r1_sha384;
default:
throw InvalidParameterError("Unsupported algorithm");
}
}
const CipherSuite::Ciphers&
CipherSuite::get() const
{
static const auto ciphers_X25519_AES128GCM_SHA256_Ed25519 =
CipherSuite::Ciphers{
HPKE(KEM::ID::DHKEM_X25519_SHA256,
KDF::ID::HKDF_SHA256,
AEAD::ID::AES_128_GCM),
Digest::get<Digest::ID::SHA256>(),
Signature::get<Signature::ID::Ed25519>(),
};
static const auto ciphers_P256_AES128GCM_SHA256_P256 = CipherSuite::Ciphers{
HPKE(
KEM::ID::DHKEM_P256_SHA256, KDF::ID::HKDF_SHA256, AEAD::ID::AES_128_GCM),
Digest::get<Digest::ID::SHA256>(),
Signature::get<Signature::ID::P256_SHA256>(),
};
static const auto ciphers_X25519_CHACHA20POLY1305_SHA256_Ed25519 =
CipherSuite::Ciphers{
HPKE(KEM::ID::DHKEM_X25519_SHA256,
KDF::ID::HKDF_SHA256,
AEAD::ID::CHACHA20_POLY1305),
Digest::get<Digest::ID::SHA256>(),
Signature::get<Signature::ID::Ed25519>(),
};
static const auto ciphers_P521_AES256GCM_SHA512_P521 = CipherSuite::Ciphers{
HPKE(
KEM::ID::DHKEM_P521_SHA512, KDF::ID::HKDF_SHA512, AEAD::ID::AES_256_GCM),
Digest::get<Digest::ID::SHA512>(),
Signature::get<Signature::ID::P521_SHA512>(),
};
static const auto ciphers_P384_AES256GCM_SHA384_P384 = CipherSuite::Ciphers{
HPKE(
KEM::ID::DHKEM_P384_SHA384, KDF::ID::HKDF_SHA384, AEAD::ID::AES_256_GCM),
Digest::get<Digest::ID::SHA384>(),
Signature::get<Signature::ID::P384_SHA384>(),
};
#if !defined(WITH_BORINGSSL)
static const auto ciphers_X448_AES256GCM_SHA512_Ed448 = CipherSuite::Ciphers{
HPKE(
KEM::ID::DHKEM_X448_SHA512, KDF::ID::HKDF_SHA512, AEAD::ID::AES_256_GCM),
Digest::get<Digest::ID::SHA512>(),
Signature::get<Signature::ID::Ed448>(),
};
static const auto ciphers_X448_CHACHA20POLY1305_SHA512_Ed448 =
CipherSuite::Ciphers{
HPKE(KEM::ID::DHKEM_X448_SHA512,
KDF::ID::HKDF_SHA512,
AEAD::ID::CHACHA20_POLY1305),
Digest::get<Digest::ID::SHA512>(),
Signature::get<Signature::ID::Ed448>(),
};
#endif
switch (id) {
case ID::unknown:
throw InvalidParameterError("Uninitialized ciphersuite");
case ID::X25519_AES128GCM_SHA256_Ed25519:
return ciphers_X25519_AES128GCM_SHA256_Ed25519;
case ID::P256_AES128GCM_SHA256_P256:
return ciphers_P256_AES128GCM_SHA256_P256;
case ID::X25519_CHACHA20POLY1305_SHA256_Ed25519:
return ciphers_X25519_CHACHA20POLY1305_SHA256_Ed25519;
case ID::P521_AES256GCM_SHA512_P521:
return ciphers_P521_AES256GCM_SHA512_P521;
case ID::P384_AES256GCM_SHA384_P384:
return ciphers_P384_AES256GCM_SHA384_P384;
#if !defined(WITH_BORINGSSL)
case ID::X448_AES256GCM_SHA512_Ed448:
return ciphers_X448_AES256GCM_SHA512_Ed448;
case ID::X448_CHACHA20POLY1305_SHA512_Ed448:
return ciphers_X448_CHACHA20POLY1305_SHA512_Ed448;
#endif
default:
throw InvalidParameterError("Unsupported ciphersuite");
}
}
struct HKDFLabel
{
uint16_t length;
bytes label;
bytes context;
TLS_SERIALIZABLE(length, label, context)
};
bytes
CipherSuite::expand_with_label(const bytes& secret,
const std::string& label,
const bytes& context,
size_t length) const
{
auto mls_label = from_ascii(std::string("MLS 1.0 ") + label);
auto length16 = static_cast<uint16_t>(length);
auto label_bytes = tls::marshal(HKDFLabel{ length16, mls_label, context });
return get().hpke.kdf.expand(secret, label_bytes, length);
}
bytes
CipherSuite::derive_secret(const bytes& secret, const std::string& label) const
{
return expand_with_label(secret, label, {}, secret_size());
}
bytes
CipherSuite::derive_tree_secret(const bytes& secret,
const std::string& label,
uint32_t generation,
size_t length) const
{
return expand_with_label(secret, label, tls::marshal(generation), length);
}
#if WITH_BORINGSSL
const std::array<CipherSuite::ID, 5> all_supported_suites = {
CipherSuite::ID::X25519_AES128GCM_SHA256_Ed25519,
CipherSuite::ID::P256_AES128GCM_SHA256_P256,
CipherSuite::ID::X25519_CHACHA20POLY1305_SHA256_Ed25519,
CipherSuite::ID::P521_AES256GCM_SHA512_P521,
CipherSuite::ID::P384_AES256GCM_SHA384_P384,
};
#else
const std::array<CipherSuite::ID, 7> all_supported_suites = {
CipherSuite::ID::X25519_AES128GCM_SHA256_Ed25519,
CipherSuite::ID::P256_AES128GCM_SHA256_P256,
CipherSuite::ID::X25519_CHACHA20POLY1305_SHA256_Ed25519,
CipherSuite::ID::P521_AES256GCM_SHA512_P521,
CipherSuite::ID::P384_AES256GCM_SHA384_P384,
CipherSuite::ID::X448_CHACHA20POLY1305_SHA512_Ed448,
CipherSuite::ID::X448_AES256GCM_SHA512_Ed448,
};
#endif
// MakeKeyPackageRef(value) = KDF.expand(
// KDF.extract("", value), "MLS 1.0 KeyPackage Reference", 16)
template<>
const bytes&
CipherSuite::reference_label<KeyPackage>()
{
static const auto label = from_ascii("MLS 1.0 KeyPackage Reference");
return label;
}
// MakeProposalRef(value) = KDF.expand(
// KDF.extract("", value), "MLS 1.0 Proposal Reference", 16)
//
// Even though the label says "Proposal", we actually hash the entire enclosing
// AuthenticatedContent object.
template<>
const bytes&
CipherSuite::reference_label<AuthenticatedContent>()
{
static const auto label = from_ascii("MLS 1.0 Proposal Reference");
return label;
}
///
/// HPKEPublicKey and HPKEPrivateKey
///
// This function produces a non-literal type, so it can't be constexpr.
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define MLS_1_0_PLUS(label) from_ascii("MLS 1.0 " label)
static bytes
mls_1_0_plus(const std::string& label)
{
auto plus = "MLS 1.0 "s + label;
return from_ascii(plus);
}
namespace encrypt_label {
const std::string update_path_node = "UpdatePathNode";
const std::string welcome = "Welcome";
} // namespace encrypt_label
struct EncryptContext
{
const bytes& label;
const bytes& content;
TLS_SERIALIZABLE(label, content)
};
HPKECiphertext
HPKEPublicKey::encrypt(CipherSuite suite,
const std::string& label,
const bytes& context,
const bytes& pt) const
{
auto label_plus = mls_1_0_plus(label);
auto encrypt_context = tls::marshal(EncryptContext{ label_plus, context });
auto pkR = suite.hpke().kem.deserialize(data);
auto [enc, ctx] = suite.hpke().setup_base_s(*pkR, encrypt_context);
auto ct = ctx.seal({}, pt);
return HPKECiphertext{ enc, ct };
}
std::tuple<bytes, bytes>
HPKEPublicKey::do_export(CipherSuite suite,
const bytes& info,
const std::string& label,
size_t size) const
{
auto label_data = from_ascii(label);
auto pkR = suite.hpke().kem.deserialize(data);
auto [enc, ctx] = suite.hpke().setup_base_s(*pkR, info);
auto exported = ctx.do_export(label_data, size);
return std::make_tuple(enc, exported);
}
HPKEPrivateKey
HPKEPrivateKey::generate(CipherSuite suite)
{
auto priv = suite.hpke().kem.generate_key_pair();
auto priv_data = suite.hpke().kem.serialize_private(*priv);
auto pub = priv->public_key();
auto pub_data = suite.hpke().kem.serialize(*pub);
return { priv_data, pub_data };
}
HPKEPrivateKey
HPKEPrivateKey::parse(CipherSuite suite, const bytes& data)
{
auto priv = suite.hpke().kem.deserialize_private(data);
auto pub = priv->public_key();
auto pub_data = suite.hpke().kem.serialize(*pub);
return { data, pub_data };
}
HPKEPrivateKey
HPKEPrivateKey::derive(CipherSuite suite, const bytes& secret)
{
auto priv = suite.hpke().kem.derive_key_pair(secret);
auto priv_data = suite.hpke().kem.serialize_private(*priv);
auto pub = priv->public_key();
auto pub_data = suite.hpke().kem.serialize(*pub);
return { priv_data, pub_data };
}
bytes
HPKEPrivateKey::decrypt(CipherSuite suite,
const std::string& label,
const bytes& context,
const HPKECiphertext& ct) const
{
auto label_plus = mls_1_0_plus(label);
auto encrypt_context = tls::marshal(EncryptContext{ label_plus, context });
auto skR = suite.hpke().kem.deserialize_private(data);
auto ctx = suite.hpke().setup_base_r(ct.kem_output, *skR, encrypt_context);
auto pt = ctx.open({}, ct.ciphertext);
if (!pt) {
throw InvalidParameterError("HPKE decryption failure");
}
return opt::get(pt);
}
bytes
HPKEPrivateKey::do_export(CipherSuite suite,
const bytes& info,
const bytes& kem_output,
const std::string& label,
size_t size) const
{
auto label_data = from_ascii(label);
auto skR = suite.hpke().kem.deserialize_private(data);
auto ctx = suite.hpke().setup_base_r(kem_output, *skR, info);
return ctx.do_export(label_data, size);
}
HPKEPrivateKey::HPKEPrivateKey(bytes priv_data, bytes pub_data)
: data(std::move(priv_data))
, public_key{ std::move(pub_data) }
{
}
void
HPKEPrivateKey::set_public_key(CipherSuite suite)
{
const auto priv = suite.hpke().kem.deserialize_private(data);
auto pub = priv->public_key();
public_key.data = suite.hpke().kem.serialize(*pub);
}
///
/// SignaturePublicKey and SignaturePrivateKey
///
namespace sign_label {
const std::string mls_content = "FramedContentTBS";
const std::string leaf_node = "LeafNodeTBS";
const std::string key_package = "KeyPackageTBS";
const std::string group_info = "GroupInfoTBS";
const std::string multi_credential = "MultiCredential";
} // namespace sign_label
struct SignContent
{
const bytes& label;
const bytes& content;
TLS_SERIALIZABLE(label, content)
};
bool
SignaturePublicKey::verify(const CipherSuite& suite,
const std::string& label,
const bytes& message,
const bytes& signature) const
{
auto label_plus = mls_1_0_plus(label);
const auto content = tls::marshal(SignContent{ label_plus, message });
auto pub = suite.sig().deserialize(data);
return suite.sig().verify(content, signature, *pub);
}
SignaturePublicKey
SignaturePublicKey::from_jwk(CipherSuite suite, const std::string& json_str)
{
auto pub = suite.sig().import_jwk(json_str);
auto pub_data = suite.sig().serialize(*pub);
return SignaturePublicKey{ pub_data };
}
std::string
SignaturePublicKey::to_jwk(CipherSuite suite) const
{
auto pub = suite.sig().deserialize(data);
return suite.sig().export_jwk(*pub);
}
PublicJWK
PublicJWK::parse(const std::string& jwk_json)
{
const auto parsed = Signature::parse_jwk(jwk_json);
const auto scheme = tls_signature_scheme(parsed.sig.id);
const auto pub_data = parsed.sig.serialize(*parsed.key);
return { scheme, parsed.key_id, { pub_data } };
}
SignaturePrivateKey
SignaturePrivateKey::generate(CipherSuite suite)
{
auto priv = suite.sig().generate_key_pair();
auto priv_data = suite.sig().serialize_private(*priv);
auto pub = priv->public_key();
auto pub_data = suite.sig().serialize(*pub);
return { priv_data, pub_data };
}
SignaturePrivateKey
SignaturePrivateKey::parse(CipherSuite suite, const bytes& data)
{
auto priv = suite.sig().deserialize_private(data);
auto pub = priv->public_key();
auto pub_data = suite.sig().serialize(*pub);
return { data, pub_data };
}
SignaturePrivateKey
SignaturePrivateKey::derive(CipherSuite suite, const bytes& secret)
{
auto priv = suite.sig().derive_key_pair(secret);
auto priv_data = suite.sig().serialize_private(*priv);
auto pub = priv->public_key();
auto pub_data = suite.sig().serialize(*pub);
return { priv_data, pub_data };
}
bytes
SignaturePrivateKey::sign(const CipherSuite& suite,
const std::string& label,
const bytes& message) const
{
auto label_plus = mls_1_0_plus(label);
const auto content = tls::marshal(SignContent{ label_plus, message });
const auto priv = suite.sig().deserialize_private(data);
return suite.sig().sign(content, *priv);
}
SignaturePrivateKey::SignaturePrivateKey(bytes priv_data, bytes pub_data)
: data(std::move(priv_data))
, public_key{ std::move(pub_data) }
{
}
void
SignaturePrivateKey::set_public_key(CipherSuite suite)
{
const auto priv = suite.sig().deserialize_private(data);
auto pub = priv->public_key();
public_key.data = suite.sig().serialize(*pub);
}
SignaturePrivateKey
SignaturePrivateKey::from_jwk(CipherSuite suite, const std::string& json_str)
{
auto priv = suite.sig().import_jwk_private(json_str);
auto priv_data = suite.sig().serialize_private(*priv);
auto pub = priv->public_key();
auto pub_data = suite.sig().serialize(*pub);
return { priv_data, pub_data };
}
std::string
SignaturePrivateKey::to_jwk(CipherSuite suite) const
{
const auto priv = suite.sig().deserialize_private(data);
return suite.sig().export_jwk_private(*priv);
}
} // namespace mlspp

126
DPP/mlspp/src/grease.cpp Executable file
View File

@@ -0,0 +1,126 @@
#include "grease.h"
#include <random>
#include <set>
namespace mlspp {
#ifdef DISABLE_GREASE
void
grease([[maybe_unused]] Capabilities& capabilities,
[[maybe_unused]] const ExtensionList& extensions)
{
}
void
grease([[maybe_unused]] ExtensionList& extensions)
{
}
#else
// Randomness parmeters:
// * Given a list of N items, insert max(1, rand(p_grease * N)) GREASE values
// * Each GREASE value added is distinct, unless more than 15 values are needed
// * For extensions, each GREASE extension has rand(n_grease_ext) random bytes
// of data
const size_t log_p_grease = 1; // -log2(p_grease) => p_grease = 1/2
const size_t max_grease_ext_size = 16;
const std::array<uint16_t, 15> grease_values = { 0x0A0A, 0x1A1A, 0x2A2A, 0x3A3A,
0x4A4A, 0x5A5A, 0x6A6A, 0x7A7A,
0x8A8A, 0x9A9A, 0xAAAA, 0xBABA,
0xCACA, 0xDADA, 0xEAEA };
static size_t
rand_int(size_t n)
{
static auto seed = std::random_device()();
static auto rng = std::mt19937(seed);
return std::uniform_int_distribution<size_t>(0, n)(rng);
}
static uint16_t
grease_value()
{
const auto where = rand_int(grease_values.size() - 1);
return grease_values.at(where);
}
static bool
grease_value(uint16_t val)
{
static constexpr auto grease_mask = uint16_t(0x0F0F);
return ((val & grease_mask) == 0x0A0A) && val != 0xFAFA;
}
static std::set<uint16_t>
grease_sample(size_t count)
{
auto vals = std::set<uint16_t>{};
while (vals.size() < count) {
uint16_t val = grease_value();
while (vals.count(val) > 0 && vals.size() < grease_values.size()) {
val = grease_value();
}
vals.insert(val);
}
return vals;
}
template<typename T>
static void
grease(std::vector<T>& vec)
{
const auto count = std::max(size_t(1), rand_int(vec.size() >> log_p_grease));
for (const auto val : grease_sample(count)) {
const auto where = static_cast<ptrdiff_t>(rand_int(vec.size()));
vec.insert(std::begin(vec) + where, static_cast<T>(val));
}
}
void
grease(Capabilities& capabilities, const ExtensionList& extensions)
{
// Add GREASE to the appropriate portions of the capabilities
grease(capabilities.cipher_suites);
grease(capabilities.extensions);
grease(capabilities.proposals);
grease(capabilities.credentials);
// Ensure that the GREASE extensions are reflected in Capabilities.extensions
for (const auto& ext : extensions.extensions) {
if (!grease_value(ext.type)) {
continue;
}
if (stdx::contains(capabilities.extensions, ext.type)) {
continue;
}
const auto where =
static_cast<ptrdiff_t>(rand_int(capabilities.extensions.size()));
const auto where_ptr = std::begin(capabilities.extensions) + where;
capabilities.extensions.insert(where_ptr, ext.type);
}
}
void
grease(ExtensionList& extensions)
{
auto& ext = extensions.extensions;
const auto count = std::max(size_t(1), rand_int(ext.size() >> log_p_grease));
for (const auto ext_type : grease_sample(count)) {
const auto where = static_cast<ptrdiff_t>(rand_int(ext.size()));
auto ext_data = random_bytes(rand_int(max_grease_ext_size));
ext.insert(std::begin(ext) + where, { ext_type, std::move(ext_data) });
}
}
#endif // DISABLE_GREASE
} // namespace mlspp

13
DPP/mlspp/src/grease.h Executable file
View File

@@ -0,0 +1,13 @@
#pragma once
#include "mls/core_types.h"
namespace mlspp {
void
grease(Capabilities& capabilities, const ExtensionList& extensions);
void
grease(ExtensionList& extensions);
} // namespace mlspp

579
DPP/mlspp/src/key_schedule.cpp Executable file
View File

@@ -0,0 +1,579 @@
#include <mls/key_schedule.h>
namespace mlspp {
///
/// Key Derivation Functions
///
struct TreeContext
{
NodeIndex node;
uint32_t generation = 0;
TLS_SERIALIZABLE(node, generation)
};
///
/// HashRatchet
///
HashRatchet::HashRatchet(CipherSuite suite_in, bytes base_secret_in)
: suite(suite_in)
, next_secret(std::move(base_secret_in))
, next_generation(0)
, key_size(suite.hpke().aead.key_size)
, nonce_size(suite.hpke().aead.nonce_size)
, secret_size(suite.secret_size())
{
}
std::tuple<uint32_t, KeyAndNonce>
HashRatchet::next()
{
auto generation = next_generation;
auto key = suite.derive_tree_secret(next_secret, "key", generation, key_size);
auto nonce =
suite.derive_tree_secret(next_secret, "nonce", generation, nonce_size);
auto secret =
suite.derive_tree_secret(next_secret, "secret", generation, secret_size);
next_generation += 1;
next_secret = secret;
cache[generation] = { key, nonce };
return { generation, cache.at(generation) };
}
// Note: This construction deliberately does not preserve the forward-secrecy
// invariant, in that keys/nonces are not deleted after they are used.
// Otherwise, it would not be possible for a node to send to itself. Keys can
// be deleted once they are not needed by calling HashRatchet::erase().
KeyAndNonce
HashRatchet::get(uint32_t generation)
{
if (cache.count(generation) > 0) {
auto out = cache.at(generation);
return out;
}
if (next_generation > generation) {
throw ProtocolError("Request for expired key");
}
while (next_generation <= generation) {
next();
}
return cache.at(generation);
}
void
HashRatchet::erase(uint32_t generation)
{
if (cache.count(generation) == 0) {
return;
}
cache.erase(generation);
}
///
/// SecretTree
///
SecretTree::SecretTree(CipherSuite suite_in,
LeafCount group_size_in,
bytes encryption_secret_in)
: suite(suite_in)
, group_size(LeafCount::full(group_size_in))
, root(NodeIndex::root(group_size))
, secret_size(suite_in.secret_size())
{
secrets.emplace(root, std::move(encryption_secret_in));
}
bytes
SecretTree::get(LeafIndex sender)
{
static const auto context_left = from_ascii("left");
static const auto context_right = from_ascii("right");
auto node = NodeIndex(sender);
// Find an ancestor that is populated
auto dirpath = node.dirpath(group_size);
dirpath.insert(dirpath.begin(), node);
dirpath.push_back(root);
uint32_t curr = 0;
for (; curr < dirpath.size(); ++curr) {
auto i = dirpath.at(curr);
if (secrets.count(i) > 0) {
break;
}
}
if (curr > dirpath.size()) {
throw InvalidParameterError("No secret found to derive base key");
}
// Derive down
for (; curr > 0; --curr) {
auto curr_node = dirpath.at(curr);
auto left = curr_node.left();
auto right = curr_node.right();
auto& secret = secrets.at(curr_node);
const auto left_secret =
suite.expand_with_label(secret, "tree", context_left, secret_size);
const auto right_secret =
suite.expand_with_label(secret, "tree", context_right, secret_size);
secrets.insert_or_assign(left, left_secret);
secrets.insert_or_assign(right, right_secret);
}
// Copy the leaf
auto out = secrets.at(node);
// Zeroize along the direct path
for (auto i : dirpath) {
secrets.erase(i);
}
return out;
}
///
/// ReuseGuard
///
static ReuseGuard
new_reuse_guard()
{
auto random = random_bytes(4);
auto guard = ReuseGuard();
std::copy(random.begin(), random.end(), guard.begin());
return guard;
}
static void
apply_reuse_guard(const ReuseGuard& guard, bytes& nonce)
{
for (size_t i = 0; i < guard.size(); i++) {
nonce.at(i) ^= guard.at(i);
}
}
///
/// GroupKeySource
///
GroupKeySource::GroupKeySource(CipherSuite suite_in,
LeafCount group_size,
bytes encryption_secret)
: suite(suite_in)
, secret_tree(suite, group_size, std::move(encryption_secret))
{
}
HashRatchet&
GroupKeySource::chain(ContentType type, LeafIndex sender)
{
switch (type) {
case ContentType::proposal:
case ContentType::commit:
return chain(RatchetType::handshake, sender);
case ContentType::application:
return chain(RatchetType::application, sender);
default:
throw InvalidParameterError("Invalid content type");
}
}
HashRatchet&
GroupKeySource::chain(RatchetType type, LeafIndex sender)
{
auto key = Key{ type, sender };
if (chains.count(key) > 0) {
return chains[key];
}
auto secret_size = suite.secret_size();
auto leaf_secret = secret_tree.get(sender);
auto handshake_secret =
suite.expand_with_label(leaf_secret, "handshake", {}, secret_size);
auto application_secret =
suite.expand_with_label(leaf_secret, "application", {}, secret_size);
chains.emplace(Key{ RatchetType::handshake, sender },
HashRatchet{ suite, handshake_secret });
chains.emplace(Key{ RatchetType::application, sender },
HashRatchet{ suite, application_secret });
return chains[key];
}
std::tuple<uint32_t, ReuseGuard, KeyAndNonce>
GroupKeySource::next(ContentType type, LeafIndex sender)
{
auto [generation, keys] = chain(type, sender).next();
auto reuse_guard = new_reuse_guard();
apply_reuse_guard(reuse_guard, keys.nonce);
return { generation, reuse_guard, keys };
}
KeyAndNonce
GroupKeySource::get(ContentType type,
LeafIndex sender,
uint32_t generation,
ReuseGuard reuse_guard)
{
auto keys = chain(type, sender).get(generation);
apply_reuse_guard(reuse_guard, keys.nonce);
return keys;
}
void
GroupKeySource::erase(ContentType type, LeafIndex sender, uint32_t generation)
{
return chain(type, sender).erase(generation);
}
// struct {
// opaque group_id<0..255>;
// uint64 epoch;
// ContentType content_type;
// opaque authenticated_data<0..2^32-1>;
// } ContentAAD;
struct ContentAAD
{
const bytes& group_id;
const epoch_t epoch;
const ContentType content_type;
const bytes& authenticated_data;
TLS_SERIALIZABLE(group_id, epoch, content_type, authenticated_data)
};
///
/// KeyScheduleEpoch
///
struct PSKLabel
{
const PreSharedKeyID& id;
uint16_t index;
uint16_t count;
TLS_SERIALIZABLE(id, index, count);
};
static bytes
make_joiner_secret(CipherSuite suite,
const bytes& context,
const bytes& init_secret,
const bytes& commit_secret)
{
auto pre_joiner_secret = suite.hpke().kdf.extract(init_secret, commit_secret);
return suite.expand_with_label(
pre_joiner_secret, "joiner", context, suite.secret_size());
}
static bytes
make_epoch_secret(CipherSuite suite,
const bytes& joiner_secret,
const bytes& psk_secret,
const bytes& context)
{
auto member_secret = suite.hpke().kdf.extract(joiner_secret, psk_secret);
return suite.expand_with_label(
member_secret, "epoch", context, suite.secret_size());
}
KeyScheduleEpoch
KeyScheduleEpoch::joiner(CipherSuite suite_in,
const bytes& joiner_secret,
const std::vector<PSKWithSecret>& psks,
const bytes& context)
{
return { suite_in, joiner_secret, make_psk_secret(suite_in, psks), context };
}
KeyScheduleEpoch::KeyScheduleEpoch(CipherSuite suite_in,
const bytes& joiner_secret,
const bytes& psk_secret,
const bytes& context)
: suite(suite_in)
, joiner_secret(joiner_secret)
, epoch_secret(
make_epoch_secret(suite_in, joiner_secret, psk_secret, context))
, sender_data_secret(suite.derive_secret(epoch_secret, "sender data"))
, encryption_secret(suite.derive_secret(epoch_secret, "encryption"))
, exporter_secret(suite.derive_secret(epoch_secret, "exporter"))
, epoch_authenticator(suite.derive_secret(epoch_secret, "authentication"))
, external_secret(suite.derive_secret(epoch_secret, "external"))
, confirmation_key(suite.derive_secret(epoch_secret, "confirm"))
, membership_key(suite.derive_secret(epoch_secret, "membership"))
, resumption_psk(suite.derive_secret(epoch_secret, "resumption"))
, init_secret(suite.derive_secret(epoch_secret, "init"))
, external_priv(HPKEPrivateKey::derive(suite, external_secret))
{
}
KeyScheduleEpoch::KeyScheduleEpoch(CipherSuite suite_in)
: suite(suite_in)
{
}
KeyScheduleEpoch::KeyScheduleEpoch(CipherSuite suite_in,
const bytes& init_secret,
const bytes& context)
: KeyScheduleEpoch(
suite_in,
make_joiner_secret(suite_in, context, init_secret, suite_in.zero()),
{ /* no PSKs */ },
context)
{
}
KeyScheduleEpoch::KeyScheduleEpoch(CipherSuite suite_in,
const bytes& init_secret,
const bytes& commit_secret,
const bytes& psk_secret,
const bytes& context)
: KeyScheduleEpoch(
suite_in,
make_joiner_secret(suite_in, context, init_secret, commit_secret),
psk_secret,
context)
{
}
std::tuple<bytes, bytes>
KeyScheduleEpoch::external_init(CipherSuite suite,
const HPKEPublicKey& external_pub)
{
auto size = suite.secret_size();
return external_pub.do_export(
suite, {}, "MLS 1.0 external init secret", size);
}
bytes
KeyScheduleEpoch::receive_external_init(const bytes& kem_output) const
{
auto size = suite.secret_size();
return external_priv.do_export(
suite, {}, kem_output, "MLS 1.0 external init secret", size);
}
KeyScheduleEpoch
KeyScheduleEpoch::next(const bytes& commit_secret,
const std::vector<PSKWithSecret>& psks,
const std::optional<bytes>& force_init_secret,
const bytes& context) const
{
return next_raw(
commit_secret, make_psk_secret(suite, psks), force_init_secret, context);
}
KeyScheduleEpoch
KeyScheduleEpoch::next_raw(const bytes& commit_secret,
const bytes& psk_secret,
const std::optional<bytes>& force_init_secret,
const bytes& context) const
{
auto actual_init_secret = init_secret;
if (force_init_secret) {
actual_init_secret = opt::get(force_init_secret);
}
return { suite, actual_init_secret, commit_secret, psk_secret, context };
}
GroupKeySource
KeyScheduleEpoch::encryption_keys(LeafCount size) const
{
return { suite, size, encryption_secret };
}
bytes
KeyScheduleEpoch::confirmation_tag(const bytes& confirmed_transcript_hash) const
{
return suite.digest().hmac(confirmation_key, confirmed_transcript_hash);
}
bytes
KeyScheduleEpoch::do_export(const std::string& label,
const bytes& context,
size_t size) const
{
auto secret = suite.derive_secret(exporter_secret, label);
auto context_hash = suite.digest().hash(context);
return suite.expand_with_label(secret, "exported", context_hash, size);
}
PSKWithSecret
KeyScheduleEpoch::resumption_psk_w_secret(ResumptionPSKUsage usage,
const bytes& group_id,
epoch_t epoch)
{
auto nonce = random_bytes(suite.secret_size());
auto psk = ResumptionPSK{ usage, group_id, epoch };
return { { psk, nonce }, resumption_psk };
}
bytes
KeyScheduleEpoch::make_psk_secret(CipherSuite suite,
const std::vector<PSKWithSecret>& psks)
{
auto psk_secret = suite.zero();
auto count = uint16_t(psks.size());
auto index = uint16_t(0);
for (const auto& psk : psks) {
auto psk_extracted = suite.hpke().kdf.extract(suite.zero(), psk.secret);
auto psk_label = tls::marshal(PSKLabel{ psk.id, index, count });
auto psk_input = suite.expand_with_label(
psk_extracted, "derived psk", psk_label, suite.secret_size());
psk_secret = suite.hpke().kdf.extract(psk_input, psk_secret);
index += 1;
}
return psk_secret;
}
bytes
KeyScheduleEpoch::welcome_secret(CipherSuite suite,
const bytes& joiner_secret,
const std::vector<PSKWithSecret>& psks)
{
auto psk_secret = make_psk_secret(suite, psks);
return welcome_secret_raw(suite, joiner_secret, psk_secret);
}
bytes
KeyScheduleEpoch::welcome_secret_raw(CipherSuite suite,
const bytes& joiner_secret,
const bytes& psk_secret)
{
auto extract = suite.hpke().kdf.extract(joiner_secret, psk_secret);
return suite.derive_secret(extract, "welcome");
}
KeyAndNonce
KeyScheduleEpoch::sender_data_keys(CipherSuite suite,
const bytes& sender_data_secret,
const bytes& ciphertext)
{
auto sample_size = suite.secret_size();
auto sample = bytes(sample_size);
if (ciphertext.size() <= sample_size) {
sample = ciphertext;
} else {
sample = ciphertext.slice(0, sample_size);
}
auto key_size = suite.hpke().aead.key_size;
auto nonce_size = suite.hpke().aead.nonce_size;
return {
suite.expand_with_label(sender_data_secret, "key", sample, key_size),
suite.expand_with_label(sender_data_secret, "nonce", sample, nonce_size),
};
}
bool
operator==(const KeyScheduleEpoch& lhs, const KeyScheduleEpoch& rhs)
{
auto epoch_secret = (lhs.epoch_secret == rhs.epoch_secret);
auto sender_data_secret = (lhs.sender_data_secret == rhs.sender_data_secret);
auto encryption_secret = (lhs.encryption_secret == rhs.encryption_secret);
auto exporter_secret = (lhs.exporter_secret == rhs.exporter_secret);
auto confirmation_key = (lhs.confirmation_key == rhs.confirmation_key);
auto init_secret = (lhs.init_secret == rhs.init_secret);
auto external_priv = (lhs.external_priv == rhs.external_priv);
return epoch_secret && sender_data_secret && encryption_secret &&
exporter_secret && confirmation_key && init_secret && external_priv;
}
// struct {
// WireFormat wire_format;
// GroupContent content; // with content.content_type == commit
// opaque signature<V>;
// } ConfirmedTranscriptHashInput;
struct ConfirmedTranscriptHashInput
{
WireFormat wire_format;
const GroupContent& content;
const bytes& signature;
TLS_SERIALIZABLE(wire_format, content, signature)
};
// struct {
// MAC confirmation_tag;
// } InterimTranscriptHashInput;
struct InterimTranscriptHashInput
{
bytes confirmation_tag;
TLS_SERIALIZABLE(confirmation_tag)
};
TranscriptHash::TranscriptHash(CipherSuite suite_in)
: suite(suite_in)
{
}
TranscriptHash::TranscriptHash(CipherSuite suite_in,
bytes confirmed_in,
const bytes& confirmation_tag)
: suite(suite_in)
, confirmed(std::move(confirmed_in))
{
update_interim(confirmation_tag);
}
void
TranscriptHash::update(const AuthenticatedContent& content_auth)
{
update_confirmed(content_auth);
update_interim(content_auth);
}
void
TranscriptHash::update_confirmed(const AuthenticatedContent& content_auth)
{
const auto transcript =
interim + content_auth.confirmed_transcript_hash_input();
confirmed = suite.digest().hash(transcript);
}
void
TranscriptHash::update_interim(const bytes& confirmation_tag)
{
const auto transcript = confirmed + tls::marshal(confirmation_tag);
interim = suite.digest().hash(transcript);
}
void
TranscriptHash::update_interim(const AuthenticatedContent& content_auth)
{
const auto transcript =
confirmed + content_auth.interim_transcript_hash_input();
interim = suite.digest().hash(transcript);
}
bool
operator==(const TranscriptHash& lhs, const TranscriptHash& rhs)
{
auto confirmed = (lhs.confirmed == rhs.confirmed);
auto interim = (lhs.interim == rhs.interim);
return confirmed && interim;
}
} // namespace mlspp

947
DPP/mlspp/src/messages.cpp Executable file
View File

@@ -0,0 +1,947 @@
#include <mls/key_schedule.h>
#include <mls/messages.h>
#include <mls/state.h>
#include <mls/treekem.h>
#include "grease.h"
namespace mlspp {
// Extensions
const Extension::Type ExternalPubExtension::type = ExtensionType::external_pub;
const Extension::Type RatchetTreeExtension::type = ExtensionType::ratchet_tree;
const Extension::Type ExternalSendersExtension::type =
ExtensionType::external_senders;
const Extension::Type SFrameParameters::type = ExtensionType::sframe_parameters;
const Extension::Type SFrameCapabilities::type =
ExtensionType::sframe_parameters;
bool
SFrameCapabilities::compatible(const SFrameParameters& params) const
{
return stdx::contains(cipher_suites, params.cipher_suite);
}
// GroupContext
GroupContext::GroupContext(CipherSuite cipher_suite_in,
bytes group_id_in,
epoch_t epoch_in,
bytes tree_hash_in,
bytes confirmed_transcript_hash_in,
ExtensionList extensions_in)
: cipher_suite(cipher_suite_in)
, group_id(std::move(group_id_in))
, epoch(epoch_in)
, tree_hash(std::move(tree_hash_in))
, confirmed_transcript_hash(std::move(confirmed_transcript_hash_in))
, extensions(std::move(extensions_in))
{
}
// GroupInfo
GroupInfo::GroupInfo(GroupContext group_context_in,
ExtensionList extensions_in,
bytes confirmation_tag_in)
: group_context(std::move(group_context_in))
, extensions(std::move(extensions_in))
, confirmation_tag(std::move(confirmation_tag_in))
, signer(0)
{
grease(extensions);
}
struct GroupInfoTBS
{
GroupContext group_context;
ExtensionList extensions;
bytes confirmation_tag;
LeafIndex signer;
TLS_SERIALIZABLE(group_context, extensions, confirmation_tag, signer)
};
bytes
GroupInfo::to_be_signed() const
{
return tls::marshal(
GroupInfoTBS{ group_context, extensions, confirmation_tag, signer });
}
void
GroupInfo::sign(const TreeKEMPublicKey& tree,
LeafIndex signer_index,
const SignaturePrivateKey& priv)
{
auto maybe_leaf = tree.leaf_node(signer_index);
if (!maybe_leaf) {
throw InvalidParameterError("Cannot sign from a blank leaf");
}
if (priv.public_key != opt::get(maybe_leaf).signature_key) {
throw InvalidParameterError("Bad key for index");
}
signer = signer_index;
signature = priv.sign(tree.suite, sign_label::group_info, to_be_signed());
}
bool
GroupInfo::verify(const TreeKEMPublicKey& tree) const
{
auto maybe_leaf = tree.leaf_node(signer);
if (!maybe_leaf) {
throw InvalidParameterError("Signer not found");
}
const auto& leaf = opt::get(maybe_leaf);
return verify(leaf.signature_key);
}
void
GroupInfo::sign(LeafIndex signer_index, const SignaturePrivateKey& priv)
{
signer = signer_index;
signature = priv.sign(
group_context.cipher_suite, sign_label::group_info, to_be_signed());
}
bool
GroupInfo::verify(const SignaturePublicKey& pub) const
{
return pub.verify(group_context.cipher_suite,
sign_label::group_info,
to_be_signed(),
signature);
}
// Welcome
Welcome::Welcome()
: cipher_suite(CipherSuite::ID::unknown)
{
}
Welcome::Welcome(CipherSuite suite,
const bytes& joiner_secret,
const std::vector<PSKWithSecret>& psks,
const GroupInfo& group_info)
: cipher_suite(suite)
, _joiner_secret(joiner_secret)
{
// Cache the list of PSK IDs
for (const auto& psk : psks) {
_psks.psks.push_back(psk.id);
}
// Pre-encrypt the GroupInfo
auto [key, nonce] = group_info_key_nonce(suite, joiner_secret, psks);
auto group_info_data = tls::marshal(group_info);
encrypted_group_info =
cipher_suite.hpke().aead.seal(key, nonce, {}, group_info_data);
}
std::optional<int>
Welcome::find(const KeyPackage& kp) const
{
auto ref = kp.ref();
for (size_t i = 0; i < secrets.size(); i++) {
if (ref == secrets[i].new_member) {
return static_cast<int>(i);
}
}
return std::nullopt;
}
void
Welcome::encrypt(const KeyPackage& kp, const std::optional<bytes>& path_secret)
{
auto gs = GroupSecrets{ _joiner_secret, std::nullopt, _psks };
if (path_secret) {
gs.path_secret = GroupSecrets::PathSecret{ opt::get(path_secret) };
}
auto gs_data = tls::marshal(gs);
auto enc_gs = kp.init_key.encrypt(
kp.cipher_suite, encrypt_label::welcome, encrypted_group_info, gs_data);
secrets.push_back({ kp.ref(), enc_gs });
}
GroupSecrets
Welcome::decrypt_secrets(int kp_index, const HPKEPrivateKey& init_priv) const
{
auto secrets_data =
init_priv.decrypt(cipher_suite,
encrypt_label::welcome,
encrypted_group_info,
secrets.at(kp_index).encrypted_group_secrets);
return tls::get<GroupSecrets>(secrets_data);
}
GroupInfo
Welcome::decrypt(const bytes& joiner_secret,
const std::vector<PSKWithSecret>& psks) const
{
auto [key, nonce] = group_info_key_nonce(cipher_suite, joiner_secret, psks);
auto group_info_data =
cipher_suite.hpke().aead.open(key, nonce, {}, encrypted_group_info);
if (!group_info_data) {
throw ProtocolError("Welcome decryption failed");
}
return tls::get<GroupInfo>(opt::get(group_info_data));
}
KeyAndNonce
Welcome::group_info_key_nonce(CipherSuite suite,
const bytes& joiner_secret,
const std::vector<PSKWithSecret>& psks)
{
auto welcome_secret =
KeyScheduleEpoch::welcome_secret(suite, joiner_secret, psks);
// XXX(RLB): These used to be done with ExpandWithLabel. Should we do that
// instead, for better domain separation? (In particular, including "mls10")
// That is what we do for the sender data key/nonce.
auto key =
suite.expand_with_label(welcome_secret, "key", {}, suite.key_size());
auto nonce =
suite.expand_with_label(welcome_secret, "nonce", {}, suite.nonce_size());
return { std::move(key), std::move(nonce) };
}
// Commit
std::optional<bytes>
Commit::valid_external() const
{
// External Commits MUST contain a path field (and is therefore a "full"
// Commit). The joiner is added at the leftmost free leaf node (just as if
// they were added with an Add proposal), and the path is calculated relative
// to that leaf node.
//
// The Commit MUST NOT include any proposals by reference, since an external
// joiner cannot determine the validity of proposals sent within the group
const auto all_by_value = stdx::all_of(proposals, [](const auto& p) {
return var::holds_alternative<Proposal>(p.content);
});
if (!path || !all_by_value) {
return std::nullopt;
}
const auto ext_init_ptr = stdx::find_if(proposals, [](const auto& p) {
const auto proposal = var::get<Proposal>(p.content);
return proposal.proposal_type() == ProposalType::external_init;
});
if (ext_init_ptr == proposals.end()) {
return std::nullopt;
}
const auto& ext_init_proposal = var::get<Proposal>(ext_init_ptr->content);
const auto& ext_init = var::get<ExternalInit>(ext_init_proposal.content);
return ext_init.kem_output;
}
// PublicMessage
Proposal::Type
Proposal::proposal_type() const
{
return tls::variant<ProposalType>::type(content).val;
}
SenderType
Sender::sender_type() const
{
return tls::variant<SenderType>::type(sender);
}
tls::ostream&
operator<<(tls::ostream& str, const GroupContentAuthData& obj)
{
switch (obj.content_type) {
case ContentType::proposal:
case ContentType::application:
return str << obj.signature;
case ContentType::commit:
return str << obj.signature << opt::get(obj.confirmation_tag);
default:
throw InvalidParameterError("Invalid content type");
}
}
tls::istream&
operator>>(tls::istream& str, GroupContentAuthData& obj)
{
switch (obj.content_type) {
case ContentType::proposal:
case ContentType::application:
return str >> obj.signature;
case ContentType::commit:
obj.confirmation_tag.emplace();
return str >> obj.signature >> opt::get(obj.confirmation_tag);
default:
throw InvalidParameterError("Invalid content type");
}
}
bool
operator==(const GroupContentAuthData& lhs, const GroupContentAuthData& rhs)
{
return lhs.content_type == rhs.content_type &&
lhs.signature == rhs.signature &&
lhs.confirmation_tag == rhs.confirmation_tag;
}
GroupContent::GroupContent(bytes group_id_in,
epoch_t epoch_in,
Sender sender_in,
bytes authenticated_data_in,
RawContent content_in)
: group_id(std::move(group_id_in))
, epoch(epoch_in)
, sender(sender_in)
, authenticated_data(std::move(authenticated_data_in))
, content(std::move(content_in))
{
}
GroupContent::GroupContent(bytes group_id_in,
epoch_t epoch_in,
Sender sender_in,
bytes authenticated_data_in,
ContentType content_type)
: group_id(std::move(group_id_in))
, epoch(epoch_in)
, sender(sender_in)
, authenticated_data(std::move(authenticated_data_in))
{
switch (content_type) {
case ContentType::commit:
content.emplace<Commit>();
break;
case ContentType::proposal:
content.emplace<Proposal>();
break;
case ContentType::application:
content.emplace<ApplicationData>();
break;
default:
throw InvalidParameterError("Invalid content type");
}
}
ContentType
GroupContent::content_type() const
{
return tls::variant<ContentType>::type(content);
}
AuthenticatedContent
AuthenticatedContent::sign(WireFormat wire_format,
GroupContent content,
CipherSuite suite,
const SignaturePrivateKey& sig_priv,
const std::optional<GroupContext>& context)
{
if (wire_format == WireFormat::mls_public_message &&
content.content_type() == ContentType::application) {
throw InvalidParameterError(
"Application data cannot be sent as PublicMessage");
}
auto content_auth = AuthenticatedContent{ wire_format, std::move(content) };
auto tbs = content_auth.to_be_signed(context);
content_auth.auth.signature =
sig_priv.sign(suite, sign_label::mls_content, tbs);
return content_auth;
}
bool
AuthenticatedContent::verify(CipherSuite suite,
const SignaturePublicKey& sig_pub,
const std::optional<GroupContext>& context) const
{
if (wire_format == WireFormat::mls_public_message &&
content.content_type() == ContentType::application) {
return false;
}
auto tbs = to_be_signed(context);
return sig_pub.verify(suite, sign_label::mls_content, tbs, auth.signature);
}
struct ConfirmedTranscriptHashInput
{
WireFormat wire_format;
const GroupContent& content;
const bytes& signature;
TLS_SERIALIZABLE(wire_format, content, signature);
};
struct InterimTranscriptHashInput
{
const bytes& confirmation_tag;
TLS_SERIALIZABLE(confirmation_tag);
};
bytes
AuthenticatedContent::confirmed_transcript_hash_input() const
{
return tls::marshal(ConfirmedTranscriptHashInput{
wire_format,
content,
auth.signature,
});
}
bytes
AuthenticatedContent::interim_transcript_hash_input() const
{
return tls::marshal(
InterimTranscriptHashInput{ opt::get(auth.confirmation_tag) });
}
void
AuthenticatedContent::set_confirmation_tag(const bytes& confirmation_tag)
{
auth.confirmation_tag = confirmation_tag;
}
bool
AuthenticatedContent::check_confirmation_tag(
const bytes& confirmation_tag) const
{
return confirmation_tag == opt::get(auth.confirmation_tag);
}
tls::ostream&
operator<<(tls::ostream& str, const AuthenticatedContent& obj)
{
return str << obj.wire_format << obj.content << obj.auth;
}
tls::istream&
operator>>(tls::istream& str, AuthenticatedContent& obj)
{
str >> obj.wire_format >> obj.content;
obj.auth.content_type = obj.content.content_type();
return str >> obj.auth;
}
bool
operator==(const AuthenticatedContent& lhs, const AuthenticatedContent& rhs)
{
return lhs.wire_format == rhs.wire_format && lhs.content == rhs.content &&
lhs.auth == rhs.auth;
}
AuthenticatedContent::AuthenticatedContent(WireFormat wire_format_in,
GroupContent content_in)
: wire_format(wire_format_in)
, content(std::move(content_in))
{
auth.content_type = content.content_type();
}
AuthenticatedContent::AuthenticatedContent(WireFormat wire_format_in,
GroupContent content_in,
GroupContentAuthData auth_in)
: wire_format(wire_format_in)
, content(std::move(content_in))
, auth(std::move(auth_in))
{
}
const AuthenticatedContent&
ValidatedContent::authenticated_content() const
{
return content_auth;
}
ValidatedContent::ValidatedContent(AuthenticatedContent content_auth_in)
: content_auth(std::move(content_auth_in))
{
}
bool
operator==(const ValidatedContent& lhs, const ValidatedContent& rhs)
{
return lhs.content_auth == rhs.content_auth;
}
struct GroupContentTBS
{
WireFormat wire_format = WireFormat::reserved;
const GroupContent& content;
const std::optional<GroupContext>& context;
};
static tls::ostream&
operator<<(tls::ostream& str, const GroupContentTBS& obj)
{
str << ProtocolVersion::mls10 << obj.wire_format << obj.content;
switch (obj.content.sender.sender_type()) {
case SenderType::member:
case SenderType::new_member_commit:
str << opt::get(obj.context);
break;
case SenderType::external:
case SenderType::new_member_proposal:
break;
default:
throw InvalidParameterError("Invalid sender type");
}
return str;
}
bytes
AuthenticatedContent::to_be_signed(
const std::optional<GroupContext>& context) const
{
return tls::marshal(GroupContentTBS{
wire_format,
content,
context,
});
}
PublicMessage
PublicMessage::protect(AuthenticatedContent content_auth,
CipherSuite suite,
const std::optional<bytes>& membership_key,
const std::optional<GroupContext>& context)
{
auto pt = PublicMessage(std::move(content_auth));
// Add the membership_mac if required
switch (pt.content.sender.sender_type()) {
case SenderType::member:
pt.membership_tag =
pt.membership_mac(suite, opt::get(membership_key), context);
break;
default:
break;
}
return pt;
}
std::optional<ValidatedContent>
PublicMessage::unprotect(CipherSuite suite,
const std::optional<bytes>& membership_key,
const std::optional<GroupContext>& context) const
{
// Verify the membership_tag if the message was sent within the group
switch (content.sender.sender_type()) {
case SenderType::member: {
auto candidate = membership_mac(suite, opt::get(membership_key), context);
if (candidate != opt::get(membership_tag)) {
return std::nullopt;
}
break;
}
default:
break;
}
return { { AuthenticatedContent{
WireFormat::mls_public_message,
content,
auth,
} } };
}
bool
PublicMessage::contains(const AuthenticatedContent& content_auth) const
{
return content == content_auth.content && auth == content_auth.auth;
}
AuthenticatedContent
PublicMessage::authenticated_content() const
{
auto auth_content = AuthenticatedContent{};
auth_content.wire_format = WireFormat::mls_public_message;
auth_content.content = content;
auth_content.auth = auth;
return auth_content;
}
PublicMessage::PublicMessage(AuthenticatedContent content_auth)
: content(std::move(content_auth.content))
, auth(std::move(content_auth.auth))
{
if (content_auth.wire_format != WireFormat::mls_public_message) {
throw InvalidParameterError("Wire format mismatch (not mls_plaintext)");
}
}
struct GroupContentTBM
{
GroupContentTBS content_tbs;
GroupContentAuthData auth;
TLS_SERIALIZABLE(content_tbs, auth);
};
bytes
PublicMessage::membership_mac(CipherSuite suite,
const bytes& membership_key,
const std::optional<GroupContext>& context) const
{
auto tbm = tls::marshal(GroupContentTBM{
{ WireFormat::mls_public_message, content, context },
auth,
});
return suite.digest().hmac(membership_key, tbm);
}
tls::ostream&
operator<<(tls::ostream& str, const PublicMessage& obj)
{
switch (obj.content.sender.sender_type()) {
case SenderType::member:
return str << obj.content << obj.auth << opt::get(obj.membership_tag);
case SenderType::external:
case SenderType::new_member_proposal:
case SenderType::new_member_commit:
return str << obj.content << obj.auth;
default:
throw InvalidParameterError("Invalid sender type");
}
}
tls::istream&
operator>>(tls::istream& str, PublicMessage& obj)
{
str >> obj.content;
obj.auth.content_type = obj.content.content_type();
str >> obj.auth;
if (obj.content.sender.sender_type() == SenderType::member) {
obj.membership_tag.emplace();
str >> opt::get(obj.membership_tag);
}
return str;
}
bool
operator==(const PublicMessage& lhs, const PublicMessage& rhs)
{
return lhs.content == rhs.content && lhs.auth == rhs.auth &&
lhs.membership_tag == rhs.membership_tag;
}
bool
operator!=(const PublicMessage& lhs, const PublicMessage& rhs)
{
return !(lhs == rhs);
}
static bytes
marshal_ciphertext_content(const GroupContent& content,
const GroupContentAuthData& auth,
size_t padding_size)
{
auto w = tls::ostream{};
var::visit([&w](const auto& val) { w << val; }, content.content);
w << auth;
w.write_raw(bytes(padding_size, 0));
return w.bytes();
}
static void
unmarshal_ciphertext_content(const bytes& content_pt,
GroupContent& content,
GroupContentAuthData& auth)
{
auto r = tls::istream(content_pt);
var::visit([&r](auto& val) { r >> val; }, content.content);
r >> auth;
const auto padding = r.bytes();
const auto nonzero = [](const auto& x) { return x != 0; };
if (stdx::any_of(padding, nonzero)) {
throw ProtocolError("Malformed AuthenticatedContentTBE padding");
}
}
struct ContentAAD
{
const bytes& group_id;
const epoch_t epoch;
const ContentType content_type;
const bytes& authenticated_data;
TLS_SERIALIZABLE(group_id, epoch, content_type, authenticated_data)
};
struct SenderData
{
LeafIndex sender{ 0 };
uint32_t generation{ 0 };
ReuseGuard reuse_guard{ 0, 0, 0, 0 };
TLS_SERIALIZABLE(sender, generation, reuse_guard)
};
struct SenderDataAAD
{
const bytes& group_id;
const epoch_t epoch;
const ContentType content_type;
TLS_SERIALIZABLE(group_id, epoch, content_type)
};
PrivateMessage
PrivateMessage::protect(AuthenticatedContent content_auth,
CipherSuite suite,
GroupKeySource& keys,
const bytes& sender_data_secret,
size_t padding_size)
{
// Pull keys from the secret tree
auto index =
var::get<MemberSender>(content_auth.content.sender.sender).sender;
auto content_type = content_auth.content.content_type();
auto [generation, reuse_guard, content_keys] = keys.next(content_type, index);
// Encrypt the content
auto content_pt = marshal_ciphertext_content(
content_auth.content, content_auth.auth, padding_size);
auto content_aad = tls::marshal(ContentAAD{
content_auth.content.group_id,
content_auth.content.epoch,
content_auth.content.content_type(),
content_auth.content.authenticated_data,
});
auto content_ct = suite.hpke().aead.seal(
content_keys.key, content_keys.nonce, content_aad, content_pt);
// Encrypt the sender data
auto sender_index =
var::get<MemberSender>(content_auth.content.sender.sender).sender;
auto sender_data_pt = tls::marshal(SenderData{
sender_index,
generation,
reuse_guard,
});
auto sender_data_aad = tls::marshal(SenderDataAAD{
content_auth.content.group_id,
content_auth.content.epoch,
content_auth.content.content_type(),
});
auto sender_data_keys =
KeyScheduleEpoch::sender_data_keys(suite, sender_data_secret, content_ct);
auto sender_data_ct = suite.hpke().aead.seal(sender_data_keys.key,
sender_data_keys.nonce,
sender_data_aad,
sender_data_pt);
return PrivateMessage{
std::move(content_auth.content),
std::move(sender_data_ct),
std::move(content_ct),
};
}
std::optional<ValidatedContent>
PrivateMessage::unprotect(CipherSuite suite,
GroupKeySource& keys,
const bytes& sender_data_secret) const
{
// Decrypt and parse the sender data
auto sender_data_keys =
KeyScheduleEpoch::sender_data_keys(suite, sender_data_secret, ciphertext);
auto sender_data_aad = tls::marshal(SenderDataAAD{
group_id,
epoch,
content_type,
});
auto sender_data_pt = suite.hpke().aead.open(sender_data_keys.key,
sender_data_keys.nonce,
sender_data_aad,
encrypted_sender_data);
if (!sender_data_pt) {
return std::nullopt;
}
auto sender_data = tls::get<SenderData>(opt::get(sender_data_pt));
if (!keys.has_leaf(sender_data.sender)) {
return std::nullopt;
}
// Decrypt the content
auto content_keys = keys.get(content_type,
sender_data.sender,
sender_data.generation,
sender_data.reuse_guard);
keys.erase(content_type, sender_data.sender, sender_data.generation);
auto content_aad = tls::marshal(ContentAAD{
group_id,
epoch,
content_type,
authenticated_data,
});
auto content_pt = suite.hpke().aead.open(
content_keys.key, content_keys.nonce, content_aad, ciphertext);
if (!content_pt) {
return std::nullopt;
}
// Parse the content
auto content = GroupContent{ group_id,
epoch,
{ MemberSender{ sender_data.sender } },
authenticated_data,
content_type };
auto auth = GroupContentAuthData{ content_type, {}, {} };
unmarshal_ciphertext_content(opt::get(content_pt), content, auth);
return { { AuthenticatedContent{
WireFormat::mls_private_message,
std::move(content),
std::move(auth),
} } };
}
PrivateMessage::PrivateMessage(GroupContent content,
bytes encrypted_sender_data_in,
bytes ciphertext_in)
: group_id(std::move(content.group_id))
, epoch(content.epoch)
, content_type(content.content_type())
, authenticated_data(std::move(content.authenticated_data))
, encrypted_sender_data(std::move(encrypted_sender_data_in))
, ciphertext(std::move(ciphertext_in))
{
}
bytes
MLSMessage::group_id() const
{
return var::visit(
overloaded{
[](const PublicMessage& pt) -> bytes { return pt.get_group_id(); },
[](const PrivateMessage& ct) -> bytes { return ct.get_group_id(); },
[](const GroupInfo& gi) -> bytes { return gi.group_context.group_id; },
[](const auto& /* unused */) -> bytes {
throw InvalidParameterError("MLSMessage has no group_id");
},
},
message);
}
epoch_t
MLSMessage::epoch() const
{
return var::visit(
overloaded{
[](const PublicMessage& pt) -> epoch_t { return pt.get_epoch(); },
[](const PrivateMessage& pt) -> epoch_t { return pt.get_epoch(); },
[](const auto& /* unused */) -> epoch_t {
throw InvalidParameterError("MLSMessage has no epoch");
},
},
message);
}
WireFormat
MLSMessage::wire_format() const
{
return tls::variant<WireFormat>::type(message);
}
MLSMessage::MLSMessage(PublicMessage public_message)
: message(std::move(public_message))
{
}
MLSMessage::MLSMessage(PrivateMessage private_message)
: message(std::move(private_message))
{
}
MLSMessage::MLSMessage(Welcome welcome)
: message(std::move(welcome))
{
}
MLSMessage::MLSMessage(GroupInfo group_info)
: message(std::move(group_info))
{
}
MLSMessage::MLSMessage(KeyPackage key_package)
: message(std::move(key_package))
{
}
MLSMessage
external_proposal(CipherSuite suite,
const bytes& group_id,
epoch_t epoch,
const Proposal& proposal,
uint32_t signer_index,
const SignaturePrivateKey& sig_priv)
{
switch (proposal.proposal_type()) {
// These proposal types are OK
case ProposalType::add:
case ProposalType::remove:
case ProposalType::psk:
case ProposalType::reinit:
case ProposalType::group_context_extensions:
break;
// These proposal types are forbidden
case ProposalType::invalid:
case ProposalType::update:
case ProposalType::external_init:
default:
throw ProtocolError("External proposal has invalid type");
}
auto content = GroupContent{ group_id,
epoch,
{ ExternalSenderIndex{ signer_index } },
{ /* no authenticated data */ },
{ proposal } };
auto content_auth = AuthenticatedContent::sign(
WireFormat::mls_public_message, std::move(content), suite, sig_priv, {});
return PublicMessage::protect(std::move(content_auth), suite, {}, {});
}
} // namespace mlspp

437
DPP/mlspp/src/session.cpp Executable file
View File

@@ -0,0 +1,437 @@
#include <mls/messages.h>
#include <mls/session.h>
#include <deque>
namespace mlspp {
///
/// Inner struct declarations for PendingJoin and Session
///
struct PendingJoin::Inner
{
const CipherSuite suite;
const HPKEPrivateKey init_priv;
const HPKEPrivateKey leaf_priv;
const SignaturePrivateKey sig_priv;
const KeyPackage key_package;
Inner(CipherSuite suite_in,
SignaturePrivateKey sig_priv_in,
Credential cred_in);
static PendingJoin create(CipherSuite suite,
SignaturePrivateKey sig_priv,
Credential cred);
};
struct Session::Inner
{
std::deque<State> history;
std::map<bytes, State> outbound_cache;
bool encrypt_handshake{ false };
explicit Inner(State state);
static Session begin(CipherSuite suite,
const bytes& group_id,
const HPKEPrivateKey& leaf_priv,
const SignaturePrivateKey& sig_priv,
const LeafNode& leaf_node);
static Session join(const HPKEPrivateKey& init_priv,
const HPKEPrivateKey& leaf_priv,
const SignaturePrivateKey& sig_priv,
const KeyPackage& key_package,
const bytes& welcome_data);
bytes fresh_secret() const;
MLSMessage import_handshake(const bytes& encoded) const;
State& for_epoch(epoch_t epoch);
};
///
/// Client
///
Client::Client(CipherSuite suite_in,
SignaturePrivateKey sig_priv_in,
Credential cred_in)
: suite(suite_in)
, sig_priv(std::move(sig_priv_in))
, cred(std::move(cred_in))
{
}
Session
Client::begin_session(const bytes& group_id) const
{
auto leaf_priv = HPKEPrivateKey::generate(suite);
auto leaf_node = LeafNode(suite,
leaf_priv.public_key,
sig_priv.public_key,
cred,
Capabilities::create_default(),
Lifetime::create_default(),
{},
sig_priv);
return Session::Inner::begin(suite, group_id, leaf_priv, sig_priv, leaf_node);
}
PendingJoin
Client::start_join() const
{
return PendingJoin::Inner::create(suite, sig_priv, cred);
}
///
/// PendingJoin
///
PendingJoin::Inner::Inner(CipherSuite suite_in,
SignaturePrivateKey sig_priv_in,
Credential cred_in)
: suite(suite_in)
, init_priv(HPKEPrivateKey::generate(suite))
, leaf_priv(HPKEPrivateKey::generate(suite))
, sig_priv(std::move(sig_priv_in))
, key_package(suite,
init_priv.public_key,
LeafNode(suite,
leaf_priv.public_key,
sig_priv.public_key,
std::move(cred_in),
Capabilities::create_default(),
Lifetime::create_default(),
{},
sig_priv),
{},
sig_priv)
{
}
PendingJoin
PendingJoin::Inner::create(CipherSuite suite,
SignaturePrivateKey sig_priv,
Credential cred)
{
auto inner =
std::make_unique<Inner>(suite, std::move(sig_priv), std::move(cred));
return { inner.release() };
}
PendingJoin::PendingJoin(PendingJoin&& other) noexcept = default;
PendingJoin&
PendingJoin::operator=(PendingJoin&& other) noexcept = default;
PendingJoin::~PendingJoin() = default;
PendingJoin::PendingJoin(Inner* inner_in)
: inner(inner_in)
{
}
bytes
PendingJoin::key_package() const
{
return tls::marshal(inner->key_package);
}
Session
PendingJoin::complete(const bytes& welcome) const
{
return Session::Inner::join(inner->init_priv,
inner->leaf_priv,
inner->sig_priv,
inner->key_package,
welcome);
}
///
/// Session
///
Session::Inner::Inner(State state)
: history{ std::move(state) }
, encrypt_handshake(true)
{
}
Session
Session::Inner::begin(CipherSuite suite,
const bytes& group_id,
const HPKEPrivateKey& leaf_priv,
const SignaturePrivateKey& sig_priv,
const LeafNode& leaf_node)
{
auto state = State(group_id, suite, leaf_priv, sig_priv, leaf_node, {});
auto inner = std::make_unique<Inner>(state);
return { inner.release() };
}
Session
Session::Inner::join(const HPKEPrivateKey& init_priv,
const HPKEPrivateKey& leaf_priv,
const SignaturePrivateKey& sig_priv,
const KeyPackage& key_package,
const bytes& welcome_data)
{
auto welcome = tls::get<Welcome>(welcome_data);
auto state = State(
init_priv, leaf_priv, sig_priv, key_package, welcome, std::nullopt, {});
auto inner = std::make_unique<Inner>(state);
return { inner.release() };
}
bytes
Session::Inner::fresh_secret() const
{
const auto suite = history.front().cipher_suite();
return random_bytes(suite.secret_size());
}
MLSMessage
Session::Inner::import_handshake(const bytes& encoded) const
{
auto msg = tls::get<MLSMessage>(encoded);
switch (msg.wire_format()) {
case WireFormat::mls_public_message:
if (encrypt_handshake) {
throw ProtocolError("Handshake not encrypted as required");
}
return msg;
case WireFormat::mls_private_message: {
if (!encrypt_handshake) {
throw ProtocolError("Unexpected handshake encryption");
}
return msg;
}
default:
throw InvalidParameterError("Illegal wire format");
}
}
State&
Session::Inner::for_epoch(epoch_t epoch)
{
for (auto& state : history) {
if (state.epoch() == epoch) {
return state;
}
}
throw MissingStateError("No state for epoch");
}
Session::Session(Session&& other) noexcept = default;
Session&
Session::operator=(Session&& other) noexcept = default;
Session::~Session() = default;
Session::Session(Inner* inner_in)
: inner(inner_in)
{
}
void
Session::encrypt_handshake(bool enabled)
{
inner->encrypt_handshake = enabled;
}
bytes
Session::add(const bytes& key_package_data)
{
auto key_package = tls::get<KeyPackage>(key_package_data);
auto proposal = inner->history.front().add(
key_package, { inner->encrypt_handshake, {}, 0 });
return tls::marshal(proposal);
}
bytes
Session::update()
{
auto leaf_secret = inner->fresh_secret();
auto leaf_priv = HPKEPrivateKey::generate(cipher_suite());
auto proposal = inner->history.front().update(
std::move(leaf_priv), {}, { inner->encrypt_handshake, {}, 0 });
return tls::marshal(proposal);
}
bytes
Session::remove(uint32_t index)
{
auto proposal = inner->history.front().remove(
RosterIndex{ index }, { inner->encrypt_handshake, {}, 0 });
return tls::marshal(proposal);
}
std::tuple<bytes, bytes>
Session::commit(const bytes& proposal)
{
return commit(std::vector<bytes>{ proposal });
}
std::tuple<bytes, bytes>
Session::commit(const std::vector<bytes>& proposals)
{
auto provisional_state = inner->history.front();
for (const auto& proposal_data : proposals) {
auto msg = inner->import_handshake(proposal_data);
auto maybe_state = provisional_state.handle(msg);
if (maybe_state) {
throw InvalidParameterError("Invalid proposal; actually a commit");
}
}
inner->history.front() = std::move(provisional_state);
return commit();
}
std::tuple<bytes, bytes>
Session::commit()
{
auto commit_secret = inner->fresh_secret();
auto encrypt = inner->encrypt_handshake;
auto [commit, welcome, new_state] = inner->history.front().commit(
commit_secret, CommitOpts{ {}, true, encrypt, {} }, { encrypt, {}, 0 });
auto commit_msg = tls::marshal(commit);
auto welcome_msg = tls::marshal(welcome);
inner->outbound_cache.insert({ commit_msg, new_state });
return std::make_tuple(welcome_msg, commit_msg);
}
bool
Session::handle(const bytes& handshake_data)
{
auto msg = inner->import_handshake(handshake_data);
auto maybe_cached_state = std::optional<State>{};
auto node = inner->outbound_cache.extract(handshake_data);
if (!node.empty()) {
maybe_cached_state = node.mapped();
}
auto maybe_next_state =
inner->history.front().handle(msg, maybe_cached_state);
if (!maybe_next_state) {
return false;
}
inner->history.emplace_front(opt::get(maybe_next_state));
return true;
}
epoch_t
Session::epoch() const
{
return inner->history.front().epoch();
}
LeafIndex
Session::index() const
{
return inner->history.front().index();
}
CipherSuite
Session::cipher_suite() const
{
return inner->history.front().cipher_suite();
}
const ExtensionList&
Session::extensions() const
{
return inner->history.front().extensions();
}
const TreeKEMPublicKey&
Session::tree() const
{
return inner->history.front().tree();
}
bytes
Session::do_export(const std::string& label,
const bytes& context,
size_t size) const
{
return inner->history.front().do_export(label, context, size);
}
GroupInfo
Session::group_info() const
{
return inner->history.front().group_info(true);
}
std::vector<LeafNode>
Session::roster() const
{
return inner->history.front().roster();
}
bytes
Session::epoch_authenticator() const
{
return inner->history.front().epoch_authenticator();
}
bytes
Session::protect(const bytes& plaintext)
{
auto msg = inner->history.front().protect({}, plaintext, 0);
return tls::marshal(msg);
}
// TODO(rlb@ipv.sx): It would be good to expose identity information
// here, since ciphertexts are authenticated per sender. Who sent
// this ciphertext?
bytes
Session::unprotect(const bytes& ciphertext)
{
auto ciphertext_obj = tls::get<MLSMessage>(ciphertext);
auto& state = inner->for_epoch(ciphertext_obj.epoch());
auto [aad, pt] = state.unprotect(ciphertext_obj);
silence_unused(aad);
return pt;
}
bool
operator==(const Session& lhs, const Session& rhs)
{
if (lhs.inner->encrypt_handshake != rhs.inner->encrypt_handshake) {
return false;
}
auto size = std::min(lhs.inner->history.size(), rhs.inner->history.size());
for (size_t i = 0; i < size; i += 1) {
if (lhs.inner->history.at(i) != rhs.inner->history.at(i)) {
return false;
}
}
return true;
}
bool
operator!=(const Session& lhs, const Session& rhs)
{
return !(lhs == rhs);
}
} // namespace mlspp

2219
DPP/mlspp/src/state.cpp Executable file

File diff suppressed because it is too large Load Diff

223
DPP/mlspp/src/tree_math.cpp Executable file
View File

@@ -0,0 +1,223 @@
#include "mls/tree_math.h"
#include "mls/common.h"
#include <algorithm>
static const uint32_t one = 0x01;
static uint32_t
log2(uint32_t x)
{
if (x == 0) {
return 0;
}
uint32_t k = 0;
while ((x >> k) > 0) {
k += 1;
}
return k - 1;
}
namespace mlspp {
LeafCount::LeafCount(const NodeCount w)
{
if (w.val == 0) {
val = 0;
return;
}
if ((w.val & one) == 0) {
throw InvalidParameterError("Only odd node counts describe trees");
}
val = (w.val >> one) + 1;
}
LeafCount
LeafCount::full(const LeafCount n)
{
auto w = uint32_t(1);
while (w < n.val) {
w <<= 1U;
}
return LeafCount{ w };
}
NodeCount::NodeCount(const LeafCount n)
: UInt32(2 * (n.val - 1) + 1)
{
}
LeafIndex::LeafIndex(NodeIndex x)
: UInt32(0)
{
if (x.val % 2 == 1) {
throw InvalidParameterError("Only even node indices describe leaves");
}
val = x.val >> 1; // NOLINT(hicpp-signed-bitwise)
}
NodeIndex
LeafIndex::ancestor(LeafIndex other) const
{
auto ln = NodeIndex(*this);
auto rn = NodeIndex(other);
if (ln == rn) {
return ln;
}
uint8_t k = 0;
while (ln != rn) {
ln.val = ln.val >> 1U;
rn.val = rn.val >> 1U;
k += 1;
}
const uint32_t prefix = ln.val << k;
const uint32_t stop = (1U << uint8_t(k - 1));
return NodeIndex{ prefix + (stop - 1) };
}
NodeIndex::NodeIndex(LeafIndex x)
: UInt32(2 * x.val)
{
}
NodeIndex
NodeIndex::root(LeafCount n)
{
if (n.val == 0) {
throw std::runtime_error("Root for zero-size tree is undefined");
}
auto w = NodeCount(n);
return NodeIndex{ (one << log2(w.val)) - 1 };
}
bool
NodeIndex::is_leaf() const
{
return val % 2 == 0;
}
bool
NodeIndex::is_below(NodeIndex other) const
{
auto lx = level();
auto ly = other.level();
return lx <= ly && (val >> (ly + 1) == other.val >> (ly + 1));
}
NodeIndex
NodeIndex::left() const
{
if (is_leaf()) {
return *this;
}
// The clang analyzer doesn't realize that is_leaf() assures that level >= 1
// NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult)
return NodeIndex{ val ^ (one << (level() - 1)) };
}
NodeIndex
NodeIndex::right() const
{
if (is_leaf()) {
return *this;
}
return NodeIndex{ val ^ (uint32_t(0x03) << (level() - 1)) };
}
NodeIndex
NodeIndex::parent() const
{
auto k = level();
return NodeIndex{ (val | (one << k)) & ~(one << (k + 1)) };
}
NodeIndex
NodeIndex::sibling() const
{
return sibling(parent());
}
NodeIndex
NodeIndex::sibling(NodeIndex ancestor) const
{
if (!is_below(ancestor)) {
throw InvalidParameterError("Node is not below claimed ancestor");
}
auto l = ancestor.left();
auto r = ancestor.right();
if (is_below(l)) {
return r;
}
return l;
}
std::vector<NodeIndex>
NodeIndex::dirpath(LeafCount n)
{
if (val >= NodeCount(n).val) {
throw InvalidParameterError("Request for dirpath outside of tree");
}
auto d = std::vector<NodeIndex>{};
auto r = root(n);
if (*this == r) {
return d;
}
auto p = parent();
while (p.val != r.val) {
d.push_back(p);
p = p.parent();
}
// Include the root except in a one-member tree
if (val != r.val) {
d.push_back(p);
}
return d;
}
std::vector<NodeIndex>
NodeIndex::copath(LeafCount n)
{
auto d = dirpath(n);
if (d.empty()) {
return {};
}
// Prepend leaf; omit root
d.insert(d.begin(), *this);
d.pop_back();
return stdx::transform<NodeIndex>(d, [](auto x) { return x.sibling(); });
}
uint32_t
NodeIndex::level() const
{
if ((val & one) == 0) {
return 0;
}
uint32_t k = 0;
while (((val >> k) & one) == 1) {
k += 1;
}
return k;
}
} // namespace mlspp

1127
DPP/mlspp/src/treekem.cpp Executable file

File diff suppressed because it is too large Load Diff