add basic discord support
This commit is contained in:
117
DPP/mlspp/CMakeLists.txt
Executable file
117
DPP/mlspp/CMakeLists.txt
Executable file
@@ -0,0 +1,117 @@
|
||||
cmake_minimum_required(VERSION 3.13)
|
||||
|
||||
project(mlspp
|
||||
VERSION 0.1
|
||||
LANGUAGES CXX
|
||||
)
|
||||
|
||||
option(TESTING "Build tests" OFF)
|
||||
option(CLANG_TIDY "Perform linting with clang-tidy" OFF)
|
||||
option(SANITIZERS "Enable sanitizers" OFF)
|
||||
option(MLS_NAMESPACE_SUFFIX "Namespace Suffix for CXX and CMake Export")
|
||||
option(DISABLE_GREASE "Disables the inclusion of MLS protocol recommended GREASE values" ON)
|
||||
option(REQUIRE_BORINGSSL "Require BoringSSL instead of OpenSSL" OFF)
|
||||
|
||||
if(MLS_NAMESPACE_SUFFIX)
|
||||
set(MLS_CXX_NAMESPACE "mls_${MLS_NAMESPACE_SUFFIX}" CACHE STRING "Top-level Namespace for CXX")
|
||||
set(MLS_EXPORT_NAMESPACE "MLSPP${MLS_NAMESPACE_SUFFIX}" CACHE STRING "Namespace for CMake Export")
|
||||
else()
|
||||
set(MLS_CXX_NAMESPACE "../include/dpp/mlspp/mls" CACHE STRING "Top-level Namespace for CXX")
|
||||
set(MLS_EXPORT_NAMESPACE "MLSPP" CACHE STRING "Namespace for CMake Export")
|
||||
endif()
|
||||
message(STATUS "CXX Namespace: ${MLS_CXX_NAMESPACE}")
|
||||
message(STATUS "CMake Export Namespace: ${MLS_EXPORT_NAMESPACE}")
|
||||
|
||||
|
||||
###
|
||||
### Global Config
|
||||
###
|
||||
set_property(GLOBAL PROPERTY USE_FOLDERS ON)
|
||||
|
||||
configure_file(
|
||||
"cmake/namespace.h.in"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/include/namespace.h"
|
||||
@ONLY
|
||||
)
|
||||
|
||||
include(CheckCXXCompilerFlag)
|
||||
include(CMakePackageConfigHelpers)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
if (CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID MATCHES "GNU")
|
||||
add_compile_options(-Wall -fPIC)
|
||||
elseif(MSVC)
|
||||
add_compile_options(/W2)
|
||||
add_definitions(-DWINDOWS)
|
||||
|
||||
# MSVC helpfully recommends safer equivalents for things like
|
||||
# getenv, but they are not portable.
|
||||
add_definitions(-D_CRT_SECURE_NO_WARNINGS)
|
||||
endif()
|
||||
|
||||
if("$ENV{MACOSX_DEPLOYMENT_TARGET}" STREQUAL "10.11")
|
||||
add_compile_options(-DVARIANT_COMPAT)
|
||||
endif()
|
||||
|
||||
add_compile_options(-DDISABLE_GREASE)
|
||||
|
||||
###
|
||||
### Dependencies
|
||||
###
|
||||
|
||||
# Configure vcpkg to only build release libraries
|
||||
set(VCPKG_BUILD_TYPE release)
|
||||
|
||||
if (${OPENSSL_VERSION} VERSION_GREATER_EQUAL 3)
|
||||
add_compile_definitions(WITH_OPENSSL3)
|
||||
elseif(${OPENSSL_VERSION} VERSION_LESS 1.1.1)
|
||||
message(FATAL_ERROR "OpenSSL 1.1.1 or greater is required")
|
||||
endif()
|
||||
message(STATUS "OpenSSL Found: ${OPENSSL_VERSION}")
|
||||
message(STATUS "OpenSSL Include: ${OPENSSL_INCLUDE_DIR}")
|
||||
message(STATUS "OpenSSL Libraries: ${OPENSSL_LIBRARIES}")
|
||||
|
||||
# Internal libraries
|
||||
add_subdirectory(lib)
|
||||
|
||||
###
|
||||
### Library Config
|
||||
###
|
||||
|
||||
set(LIB_NAME "${PROJECT_NAME}")
|
||||
|
||||
file(GLOB_RECURSE LIB_HEADERS CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/include/*.h")
|
||||
file(GLOB_RECURSE LIB_SOURCES CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp")
|
||||
|
||||
add_library(${LIB_NAME} STATIC ${LIB_HEADERS} ${LIB_SOURCES})
|
||||
add_dependencies(${LIB_NAME} bytes tls_syntax hpke)
|
||||
target_link_libraries(${LIB_NAME} bytes tls_syntax hpke)
|
||||
target_include_directories(${LIB_NAME}
|
||||
PUBLIC
|
||||
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/include>
|
||||
$<INSTALL_INTERFACE:include/${PROJECT_NAME}>
|
||||
PRIVATE
|
||||
${OPENSSL_INCLUDE_DIR}
|
||||
)
|
||||
|
||||
###
|
||||
### Exports
|
||||
###
|
||||
set(CMAKE_EXPORT_PACKAGE_REGISTRY ON)
|
||||
export(PACKAGE ${MLS_EXPORT_NAMESPACE})
|
||||
|
||||
configure_package_config_file(cmake/config.cmake.in
|
||||
${CMAKE_CURRENT_BINARY_DIR}/${MLS_EXPORT_NAMESPACE}Config.cmake
|
||||
INSTALL_DESTINATION ${CMAKE_INSTALL_DATADIR}/${MLS_EXPORT_NAMESPACE}
|
||||
NO_SET_AND_CHECK_MACRO)
|
||||
|
||||
write_basic_package_version_file(
|
||||
${CMAKE_CURRENT_BINARY_DIR}/${MLS_EXPORT_NAMESPACE}ConfigVersion.cmake
|
||||
VERSION ${PROJECT_VERSION}
|
||||
COMPATIBILITY SameMajorVersion)
|
||||
|
||||
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/lib/bytes/include")
|
||||
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/lib/hpke/include")
|
||||
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/lib/mls_vectors/include")
|
||||
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/lib/tls_syntax/include")
|
||||
25
DPP/mlspp/LICENSE
Executable file
25
DPP/mlspp/LICENSE
Executable file
@@ -0,0 +1,25 @@
|
||||
BSD 2-Clause License
|
||||
|
||||
Copyright (c) 2018, Cisco Systems
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
4
DPP/mlspp/cmake/config.cmake.in
Executable file
4
DPP/mlspp/cmake/config.cmake.in
Executable file
@@ -0,0 +1,4 @@
|
||||
@PACKAGE_INIT@
|
||||
|
||||
include(${CMAKE_CURRENT_LIST_DIR}/@MLS_EXPORT_NAMESPACE@Targets.cmake)
|
||||
check_required_components(mlspp)
|
||||
4
DPP/mlspp/cmake/namespace.h.in
Executable file
4
DPP/mlspp/cmake/namespace.h.in
Executable file
@@ -0,0 +1,4 @@
|
||||
#pragma once
|
||||
|
||||
// Configurable top-level MLS namespace
|
||||
#define MLS_NAMESPACE @MLS_CXX_NAMESPACE@
|
||||
274
DPP/mlspp/include/mls/common.h
Executable file
274
DPP/mlspp/include/mls/common.h
Executable file
@@ -0,0 +1,274 @@
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include <iomanip>
|
||||
#include <iterator>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
using namespace std::literals::string_literals;
|
||||
|
||||
// Expose the bytes library globally
|
||||
#include <bytes/bytes.h>
|
||||
using namespace mlspp::bytes_ns;
|
||||
|
||||
// Expose the compatibility library globally
|
||||
#include <tls/compat.h>
|
||||
namespace var = mlspp::tls::var;
|
||||
namespace opt = mlspp::tls::opt;
|
||||
|
||||
namespace mlspp {
|
||||
|
||||
// Make variant equality work in the same way as optional equality, with
|
||||
// automatic unwrapping. In other words
|
||||
//
|
||||
// v == T(x) <=> hold_alternative<T>(v) && get<T>(v) == x
|
||||
//
|
||||
// For consistency, we also define symmetric and negated version. In this
|
||||
// house, we obey the symmetric law of equivalence relations!
|
||||
template<typename T, typename... Ts>
|
||||
bool
|
||||
operator==(const var::variant<Ts...>& v, const T& t)
|
||||
{
|
||||
return var::visit(
|
||||
[&](const auto& arg) {
|
||||
using U = std::decay_t<decltype(arg)>;
|
||||
if constexpr (std::is_same_v<U, T>) {
|
||||
return arg == t;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
},
|
||||
v);
|
||||
}
|
||||
|
||||
template<typename T, typename... Ts>
|
||||
bool
|
||||
operator==(const T& t, const var::variant<Ts...>& v)
|
||||
{
|
||||
return v == t;
|
||||
}
|
||||
|
||||
template<typename T, typename... Ts>
|
||||
bool
|
||||
operator!=(const var::variant<Ts...>& v, const T& t)
|
||||
{
|
||||
return !(v == t);
|
||||
}
|
||||
|
||||
template<typename T, typename... Ts>
|
||||
bool
|
||||
operator!=(const T& t, const var::variant<Ts...>& v)
|
||||
{
|
||||
return !(v == t);
|
||||
}
|
||||
|
||||
using epoch_t = uint64_t;
|
||||
|
||||
///
|
||||
/// Get the current system clock time in the format MLS expects
|
||||
///
|
||||
|
||||
uint64_t
|
||||
seconds_since_epoch();
|
||||
|
||||
///
|
||||
/// Easy construction of overloaded lambdas
|
||||
///
|
||||
|
||||
template<class... Ts>
|
||||
struct overloaded : Ts...
|
||||
{
|
||||
using Ts::operator()...;
|
||||
|
||||
// XXX(RLB) MSVC has a bug where it incorrectly computes the size of this
|
||||
// type. Microsoft claims they have fixed it in the latest MSVC, and GitHub
|
||||
// claims they are running a version with the fix. But in practice, we still
|
||||
// hit it. Including this dummy variable is a work-around.
|
||||
//
|
||||
// https://developercommunity.visualstudio.com/t/runtime-stack-corruption-using-stdvisit/346200
|
||||
int dummy = 0;
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
// XXX(RLB): For some reason, different versions of clang-format disagree on how
|
||||
// this should be formatted. Probably because it's new syntax with C++17?
|
||||
// Exempting it from clang-format for now.
|
||||
template<class... Ts> overloaded(Ts...) -> overloaded<Ts...>;
|
||||
// clang-format on
|
||||
|
||||
///
|
||||
/// Auto-generate equality and inequality operators for TLS-serializable things
|
||||
///
|
||||
|
||||
template<typename T>
|
||||
inline typename std::enable_if<T::_tls_serializable, bool>::type
|
||||
operator==(const T& lhs, const T& rhs)
|
||||
{
|
||||
return lhs._tls_fields_w() == rhs._tls_fields_w();
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline typename std::enable_if<T::_tls_serializable, bool>::type
|
||||
operator!=(const T& lhs, const T& rhs)
|
||||
{
|
||||
return lhs._tls_fields_w() != rhs._tls_fields_w();
|
||||
}
|
||||
|
||||
///
|
||||
/// Error types
|
||||
///
|
||||
|
||||
// The `using parent = X` / `using parent::parent` construction here
|
||||
// imports the constructors of the parent.
|
||||
|
||||
class NotImplementedError : public std::exception
|
||||
{
|
||||
public:
|
||||
using parent = std::exception;
|
||||
using parent::parent;
|
||||
};
|
||||
|
||||
class ProtocolError : public std::runtime_error
|
||||
{
|
||||
public:
|
||||
using parent = std::runtime_error;
|
||||
using parent::parent;
|
||||
};
|
||||
|
||||
class IncompatibleNodesError : public std::invalid_argument
|
||||
{
|
||||
public:
|
||||
using parent = std::invalid_argument;
|
||||
using parent::parent;
|
||||
};
|
||||
|
||||
class InvalidParameterError : public std::invalid_argument
|
||||
{
|
||||
public:
|
||||
using parent = std::invalid_argument;
|
||||
using parent::parent;
|
||||
};
|
||||
|
||||
class InvalidPathError : public std::invalid_argument
|
||||
{
|
||||
public:
|
||||
using parent = std::invalid_argument;
|
||||
using parent::parent;
|
||||
};
|
||||
|
||||
class InvalidIndexError : public std::invalid_argument
|
||||
{
|
||||
public:
|
||||
using parent = std::invalid_argument;
|
||||
using parent::parent;
|
||||
};
|
||||
|
||||
class InvalidMessageTypeError : public std::invalid_argument
|
||||
{
|
||||
public:
|
||||
using parent = std::invalid_argument;
|
||||
using parent::parent;
|
||||
};
|
||||
|
||||
class MissingNodeError : public std::out_of_range
|
||||
{
|
||||
public:
|
||||
using parent = std::out_of_range;
|
||||
using parent::parent;
|
||||
};
|
||||
|
||||
class MissingStateError : public std::out_of_range
|
||||
{
|
||||
public:
|
||||
using parent = std::out_of_range;
|
||||
using parent::parent;
|
||||
};
|
||||
|
||||
// A slightly more elegant way to silence -Werror=unused-variable
|
||||
template<typename T>
|
||||
void
|
||||
silence_unused(const T& val)
|
||||
{
|
||||
(void)val;
|
||||
}
|
||||
|
||||
namespace stdx {
|
||||
|
||||
// XXX(RLB) This method takes any container in, but always puts the resuls in
|
||||
// std::vector. The output could be made generic with a Rust-like syntax,
|
||||
// defining a PendingTransform object that caches the inputs, with a template
|
||||
// `collect()` method that puts them in an output container. Which makes the
|
||||
// calling syntax as follows:
|
||||
//
|
||||
// auto out = stdx::transform(in, f).collect<Container>();
|
||||
//
|
||||
// (You always need the explicit specialization, even if assigning it to an
|
||||
// explicitly typed variable, because C++ won't infer return types.)
|
||||
//
|
||||
// Given that the above syntax is pretty chatty, and we never need anything
|
||||
// other than vectors here anyway, I have left this as-is.
|
||||
template<typename Value, typename Container, typename UnaryOperation>
|
||||
std::vector<Value>
|
||||
transform(const Container& c, const UnaryOperation& op)
|
||||
{
|
||||
auto out = std::vector<Value>{};
|
||||
auto ins = std::inserter(out, out.begin());
|
||||
std::transform(c.begin(), c.end(), ins, op);
|
||||
return out;
|
||||
}
|
||||
|
||||
template<typename Container, typename UnaryPredicate>
|
||||
bool
|
||||
any_of(const Container& c, const UnaryPredicate& pred)
|
||||
{
|
||||
return std::any_of(c.begin(), c.end(), pred);
|
||||
}
|
||||
|
||||
template<typename Container, typename UnaryPredicate>
|
||||
bool
|
||||
all_of(const Container& c, const UnaryPredicate& pred)
|
||||
{
|
||||
return std::all_of(c.begin(), c.end(), pred);
|
||||
}
|
||||
|
||||
template<typename Container, typename UnaryPredicate>
|
||||
auto
|
||||
count_if(const Container& c, const UnaryPredicate& pred)
|
||||
{
|
||||
return std::count_if(c.begin(), c.end(), pred);
|
||||
}
|
||||
|
||||
template<typename Container, typename Value>
|
||||
bool
|
||||
contains(const Container& c, const Value& val)
|
||||
{
|
||||
return std::find(c.begin(), c.end(), val) != c.end();
|
||||
}
|
||||
|
||||
template<typename Container, typename UnaryPredicate>
|
||||
auto
|
||||
find_if(Container& c, const UnaryPredicate& pred)
|
||||
{
|
||||
return std::find_if(c.begin(), c.end(), pred);
|
||||
}
|
||||
|
||||
template<typename Container, typename UnaryPredicate>
|
||||
auto
|
||||
find_if(const Container& c, const UnaryPredicate& pred)
|
||||
{
|
||||
return std::find_if(c.begin(), c.end(), pred);
|
||||
}
|
||||
|
||||
template<typename Container, typename Value>
|
||||
auto
|
||||
upper_bound(const Container& c, const Value& val)
|
||||
{
|
||||
return std::upper_bound(c.begin(), c.end(), val);
|
||||
}
|
||||
|
||||
} // namespace stdx
|
||||
|
||||
} // namespace mlspp
|
||||
380
DPP/mlspp/include/mls/core_types.h
Executable file
380
DPP/mlspp/include/mls/core_types.h
Executable file
@@ -0,0 +1,380 @@
|
||||
#pragma once
|
||||
|
||||
#include "mls/credential.h"
|
||||
#include "mls/crypto.h"
|
||||
#include "mls/tree_math.h"
|
||||
|
||||
namespace mlspp {
|
||||
|
||||
// enum {
|
||||
// reserved(0),
|
||||
// mls10(1),
|
||||
// (255)
|
||||
// } ProtocolVersion;
|
||||
enum class ProtocolVersion : uint16_t
|
||||
{
|
||||
mls10 = 0x01,
|
||||
};
|
||||
|
||||
extern const std::array<ProtocolVersion, 1> all_supported_versions;
|
||||
|
||||
// struct {
|
||||
// ExtensionType extension_type;
|
||||
// opaque extension_data<V>;
|
||||
// } Extension;
|
||||
struct Extension
|
||||
{
|
||||
using Type = uint16_t;
|
||||
|
||||
Type type;
|
||||
bytes data;
|
||||
|
||||
TLS_SERIALIZABLE(type, data)
|
||||
};
|
||||
|
||||
struct ExtensionType
|
||||
{
|
||||
static constexpr Extension::Type application_id = 1;
|
||||
static constexpr Extension::Type ratchet_tree = 2;
|
||||
static constexpr Extension::Type required_capabilities = 3;
|
||||
static constexpr Extension::Type external_pub = 4;
|
||||
static constexpr Extension::Type external_senders = 5;
|
||||
|
||||
// XXX(RLB) There is no IANA-registered type for this extension yet, so we use
|
||||
// a value from the vendor-specific space
|
||||
static constexpr Extension::Type sframe_parameters = 0xff02;
|
||||
};
|
||||
|
||||
struct ExtensionList
|
||||
{
|
||||
std::vector<Extension> extensions;
|
||||
|
||||
// XXX(RLB) It would be good if this maintained extensions in order. It might
|
||||
// be possible to do this automatically by changing the storage to a
|
||||
// map<ExtensionType, bytes> and extending the TLS code to marshal that type.
|
||||
template<typename T>
|
||||
inline void add(const T& obj)
|
||||
{
|
||||
auto data = tls::marshal(obj);
|
||||
add(T::type, std::move(data));
|
||||
}
|
||||
|
||||
void add(Extension::Type type, bytes data);
|
||||
|
||||
template<typename T>
|
||||
std::optional<T> find() const
|
||||
{
|
||||
for (const auto& ext : extensions) {
|
||||
if (ext.type == T::type) {
|
||||
return tls::get<T>(ext.data);
|
||||
}
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
bool has(uint16_t type) const;
|
||||
|
||||
TLS_SERIALIZABLE(extensions)
|
||||
};
|
||||
|
||||
// enum {
|
||||
// reserved(0),
|
||||
// key_package(1),
|
||||
// update(2),
|
||||
// commit(3),
|
||||
// (255)
|
||||
// } LeafNodeSource;
|
||||
enum struct LeafNodeSource : uint8_t
|
||||
{
|
||||
key_package = 1,
|
||||
update = 2,
|
||||
commit = 3,
|
||||
};
|
||||
|
||||
// struct {
|
||||
// ProtocolVersion versions<V>;
|
||||
// CipherSuite ciphersuites<V>;
|
||||
// ExtensionType extensions<V>;
|
||||
// ProposalType proposals<V>;
|
||||
// CredentialType credentials<V>;
|
||||
// } Capabilities;
|
||||
struct Capabilities
|
||||
{
|
||||
std::vector<ProtocolVersion> versions;
|
||||
std::vector<CipherSuite::ID> cipher_suites;
|
||||
std::vector<Extension::Type> extensions;
|
||||
std::vector<uint16_t> proposals;
|
||||
std::vector<CredentialType> credentials;
|
||||
|
||||
static Capabilities create_default();
|
||||
bool extensions_supported(const std::vector<Extension::Type>& required) const;
|
||||
bool proposals_supported(const std::vector<uint16_t>& required) const;
|
||||
bool credential_supported(const Credential& credential) const;
|
||||
|
||||
template<typename Container>
|
||||
bool credentials_supported(const Container& required) const
|
||||
{
|
||||
return stdx::all_of(required, [&](CredentialType type) {
|
||||
return stdx::contains(credentials, type);
|
||||
});
|
||||
}
|
||||
|
||||
TLS_SERIALIZABLE(versions, cipher_suites, extensions, proposals, credentials)
|
||||
};
|
||||
|
||||
// struct {
|
||||
// uint64 not_before;
|
||||
// uint64 not_after;
|
||||
// } Lifetime;
|
||||
struct Lifetime
|
||||
{
|
||||
uint64_t not_before;
|
||||
uint64_t not_after;
|
||||
|
||||
static Lifetime create_default();
|
||||
|
||||
TLS_SERIALIZABLE(not_before, not_after)
|
||||
};
|
||||
|
||||
// struct {
|
||||
// HPKEPublicKey encryption_key;
|
||||
// SignaturePublicKey signature_key;
|
||||
// Credential credential;
|
||||
// Capabilities capabilities;
|
||||
//
|
||||
// LeafNodeSource leaf_node_source;
|
||||
// select (leaf_node_source) {
|
||||
// case add:
|
||||
// Lifetime lifetime;
|
||||
//
|
||||
// case update:
|
||||
// struct {}
|
||||
//
|
||||
// case commit:
|
||||
// opaque parent_hash<V>;
|
||||
// }
|
||||
//
|
||||
// Extension extensions<V>;
|
||||
// // SignWithLabel(., "LeafNodeTBS", LeafNodeTBS)
|
||||
// opaque signature<V>;
|
||||
// } LeafNode;
|
||||
struct Empty
|
||||
{
|
||||
TLS_SERIALIZABLE()
|
||||
};
|
||||
|
||||
struct ParentHash
|
||||
{
|
||||
bytes parent_hash;
|
||||
TLS_SERIALIZABLE(parent_hash);
|
||||
};
|
||||
|
||||
struct LeafNodeOptions
|
||||
{
|
||||
std::optional<Credential> credential;
|
||||
std::optional<Capabilities> capabilities;
|
||||
std::optional<ExtensionList> extensions;
|
||||
};
|
||||
|
||||
// TODO Move this to treekem.h
|
||||
struct LeafNode
|
||||
{
|
||||
HPKEPublicKey encryption_key;
|
||||
SignaturePublicKey signature_key;
|
||||
Credential credential;
|
||||
Capabilities capabilities;
|
||||
|
||||
var::variant<Lifetime, Empty, ParentHash> content;
|
||||
|
||||
ExtensionList extensions;
|
||||
bytes signature;
|
||||
|
||||
LeafNode() = default;
|
||||
LeafNode(const LeafNode&) = default;
|
||||
LeafNode(LeafNode&&) = default;
|
||||
LeafNode& operator=(const LeafNode&) = default;
|
||||
LeafNode& operator=(LeafNode&&) = default;
|
||||
|
||||
LeafNode(CipherSuite cipher_suite,
|
||||
HPKEPublicKey encryption_key_in,
|
||||
SignaturePublicKey signature_key_in,
|
||||
Credential credential_in,
|
||||
Capabilities capabilities_in,
|
||||
Lifetime lifetime_in,
|
||||
ExtensionList extensions_in,
|
||||
const SignaturePrivateKey& sig_priv);
|
||||
|
||||
LeafNode for_update(CipherSuite cipher_suite,
|
||||
const bytes& group_id,
|
||||
LeafIndex leaf_index,
|
||||
HPKEPublicKey encryption_key,
|
||||
const LeafNodeOptions& opts,
|
||||
const SignaturePrivateKey& sig_priv_in) const;
|
||||
|
||||
LeafNode for_commit(CipherSuite cipher_suite,
|
||||
const bytes& group_id,
|
||||
LeafIndex leaf_index,
|
||||
HPKEPublicKey encryption_key,
|
||||
const bytes& parent_hash,
|
||||
const LeafNodeOptions& opts,
|
||||
const SignaturePrivateKey& sig_priv_in) const;
|
||||
|
||||
void set_capabilities(Capabilities capabilities_in);
|
||||
|
||||
LeafNodeSource source() const;
|
||||
|
||||
struct MemberBinding
|
||||
{
|
||||
bytes group_id;
|
||||
LeafIndex leaf_index;
|
||||
TLS_SERIALIZABLE(group_id, leaf_index);
|
||||
};
|
||||
|
||||
void sign(CipherSuite cipher_suite,
|
||||
const SignaturePrivateKey& sig_priv,
|
||||
const std::optional<MemberBinding>& binding);
|
||||
bool verify(CipherSuite cipher_suite,
|
||||
const std::optional<MemberBinding>& binding) const;
|
||||
|
||||
bool verify_expiry(uint64_t now) const;
|
||||
bool verify_extension_support(const ExtensionList& ext_list) const;
|
||||
|
||||
TLS_SERIALIZABLE(encryption_key,
|
||||
signature_key,
|
||||
credential,
|
||||
capabilities,
|
||||
content,
|
||||
extensions,
|
||||
signature)
|
||||
TLS_TRAITS(tls::pass,
|
||||
tls::pass,
|
||||
tls::pass,
|
||||
tls::pass,
|
||||
tls::variant<LeafNodeSource>,
|
||||
tls::pass,
|
||||
tls::pass)
|
||||
|
||||
private:
|
||||
LeafNode clone_with_options(HPKEPublicKey encryption_key,
|
||||
const LeafNodeOptions& opts) const;
|
||||
bytes to_be_signed(const std::optional<MemberBinding>& binding) const;
|
||||
};
|
||||
|
||||
// Concrete extension types
|
||||
struct RequiredCapabilitiesExtension
|
||||
{
|
||||
std::vector<Extension::Type> extensions;
|
||||
std::vector<uint16_t> proposals;
|
||||
|
||||
static const Extension::Type type;
|
||||
TLS_SERIALIZABLE(extensions, proposals)
|
||||
};
|
||||
|
||||
struct ApplicationIDExtension
|
||||
{
|
||||
bytes id;
|
||||
|
||||
static const Extension::Type type;
|
||||
TLS_SERIALIZABLE(id)
|
||||
};
|
||||
|
||||
///
|
||||
/// NodeType, ParentNode, and KeyPackage
|
||||
///
|
||||
|
||||
// TODO move this to treekem.h
|
||||
struct ParentNode
|
||||
{
|
||||
HPKEPublicKey public_key;
|
||||
bytes parent_hash;
|
||||
std::vector<LeafIndex> unmerged_leaves;
|
||||
|
||||
bytes hash(CipherSuite suite) const;
|
||||
|
||||
TLS_SERIALIZABLE(public_key, parent_hash, unmerged_leaves)
|
||||
};
|
||||
|
||||
// TODO Move this to messages.h
|
||||
// struct {
|
||||
// ProtocolVersion version;
|
||||
// CipherSuite cipher_suite;
|
||||
// HPKEPublicKey init_key;
|
||||
// LeafNode leaf_node;
|
||||
// Extension extensions<V>;
|
||||
// // SignWithLabel(., "KeyPackageTBS", KeyPackageTBS)
|
||||
// opaque signature<V>;
|
||||
// } KeyPackage;
|
||||
struct KeyPackage
|
||||
{
|
||||
ProtocolVersion version;
|
||||
CipherSuite cipher_suite;
|
||||
HPKEPublicKey init_key;
|
||||
LeafNode leaf_node;
|
||||
ExtensionList extensions;
|
||||
bytes signature;
|
||||
|
||||
KeyPackage();
|
||||
KeyPackage(CipherSuite suite_in,
|
||||
HPKEPublicKey init_key_in,
|
||||
LeafNode leaf_node_in,
|
||||
ExtensionList extensions_in,
|
||||
const SignaturePrivateKey& sig_priv_in);
|
||||
|
||||
KeyPackageRef ref() const;
|
||||
|
||||
void sign(const SignaturePrivateKey& sig_priv);
|
||||
bool verify() const;
|
||||
|
||||
TLS_SERIALIZABLE(version,
|
||||
cipher_suite,
|
||||
init_key,
|
||||
leaf_node,
|
||||
extensions,
|
||||
signature)
|
||||
|
||||
private:
|
||||
bytes to_be_signed() const;
|
||||
};
|
||||
|
||||
///
|
||||
/// UpdatePath
|
||||
///
|
||||
|
||||
// struct {
|
||||
// HPKEPublicKey public_key;
|
||||
// HPKECiphertext encrypted_path_secret<V>;
|
||||
// } UpdatePathNode;
|
||||
struct UpdatePathNode
|
||||
{
|
||||
HPKEPublicKey public_key;
|
||||
std::vector<HPKECiphertext> encrypted_path_secret;
|
||||
|
||||
TLS_SERIALIZABLE(public_key, encrypted_path_secret)
|
||||
};
|
||||
|
||||
// struct {
|
||||
// LeafNode leaf_node;
|
||||
// UpdatePathNode nodes<V>;
|
||||
// } UpdatePath;
|
||||
struct UpdatePath
|
||||
{
|
||||
LeafNode leaf_node;
|
||||
std::vector<UpdatePathNode> nodes;
|
||||
|
||||
TLS_SERIALIZABLE(leaf_node, nodes)
|
||||
};
|
||||
|
||||
} // namespace mlspp
|
||||
|
||||
namespace mlspp::tls {
|
||||
|
||||
TLS_VARIANT_MAP(mlspp::LeafNodeSource,
|
||||
mlspp::Lifetime,
|
||||
key_package)
|
||||
TLS_VARIANT_MAP(mlspp::LeafNodeSource, mlspp::Empty, update)
|
||||
TLS_VARIANT_MAP(mlspp::LeafNodeSource,
|
||||
mlspp::ParentHash,
|
||||
commit)
|
||||
|
||||
} // namespace mlspp::tls
|
||||
228
DPP/mlspp/include/mls/credential.h
Executable file
228
DPP/mlspp/include/mls/credential.h
Executable file
@@ -0,0 +1,228 @@
|
||||
#pragma once
|
||||
|
||||
#include <mls/common.h>
|
||||
#include <mls/crypto.h>
|
||||
|
||||
namespace mlspp {
|
||||
|
||||
namespace hpke {
|
||||
struct UserInfoVC;
|
||||
}
|
||||
|
||||
// struct {
|
||||
// opaque identity<0..2^16-1>;
|
||||
// SignaturePublicKey public_key;
|
||||
// } BasicCredential;
|
||||
struct BasicCredential
|
||||
{
|
||||
BasicCredential() {}
|
||||
|
||||
BasicCredential(bytes identity_in)
|
||||
: identity(std::move(identity_in))
|
||||
{
|
||||
}
|
||||
|
||||
bytes identity;
|
||||
|
||||
TLS_SERIALIZABLE(identity)
|
||||
};
|
||||
|
||||
struct X509Credential
|
||||
{
|
||||
struct CertData
|
||||
{
|
||||
bytes data;
|
||||
|
||||
TLS_SERIALIZABLE(data)
|
||||
};
|
||||
|
||||
X509Credential() = default;
|
||||
explicit X509Credential(const std::vector<bytes>& der_chain_in);
|
||||
|
||||
SignatureScheme signature_scheme() const;
|
||||
SignaturePublicKey public_key() const;
|
||||
bool valid_for(const SignaturePublicKey& pub) const;
|
||||
|
||||
// TODO(rlb) This should be const or exposed via a method
|
||||
std::vector<CertData> der_chain;
|
||||
|
||||
private:
|
||||
SignaturePublicKey _public_key;
|
||||
SignatureScheme _signature_scheme;
|
||||
};
|
||||
|
||||
tls::ostream&
|
||||
operator<<(tls::ostream& str, const X509Credential& obj);
|
||||
|
||||
tls::istream&
|
||||
operator>>(tls::istream& str, X509Credential& obj);
|
||||
|
||||
struct UserInfoVCCredential
|
||||
{
|
||||
UserInfoVCCredential() = default;
|
||||
explicit UserInfoVCCredential(std::string userinfo_vc_jwt_in);
|
||||
|
||||
std::string userinfo_vc_jwt;
|
||||
|
||||
bool valid_for(const SignaturePublicKey& pub) const;
|
||||
bool valid_from(const PublicJWK& pub) const;
|
||||
|
||||
friend tls::ostream operator<<(tls::ostream& str,
|
||||
const UserInfoVCCredential& obj);
|
||||
friend tls::istream operator>>(tls::istream& str, UserInfoVCCredential& obj);
|
||||
friend bool operator==(const UserInfoVCCredential& lhs,
|
||||
const UserInfoVCCredential& rhs);
|
||||
friend bool operator!=(const UserInfoVCCredential& lhs,
|
||||
const UserInfoVCCredential& rhs);
|
||||
|
||||
private:
|
||||
std::shared_ptr<hpke::UserInfoVC> _vc;
|
||||
};
|
||||
|
||||
bool
|
||||
operator==(const X509Credential& lhs, const X509Credential& rhs);
|
||||
|
||||
enum struct CredentialType : uint16_t
|
||||
{
|
||||
reserved = 0,
|
||||
basic = 1,
|
||||
x509 = 2,
|
||||
|
||||
userinfo_vc_draft_00 = 0xFE00,
|
||||
multi_draft_00 = 0xFF00,
|
||||
|
||||
// GREASE values, included here mainly so that debugger output looks nice
|
||||
GREASE_0 = 0x0A0A,
|
||||
GREASE_1 = 0x1A1A,
|
||||
GREASE_2 = 0x2A2A,
|
||||
GREASE_3 = 0x3A3A,
|
||||
GREASE_4 = 0x4A4A,
|
||||
GREASE_5 = 0x5A5A,
|
||||
GREASE_6 = 0x6A6A,
|
||||
GREASE_7 = 0x7A7A,
|
||||
GREASE_8 = 0x8A8A,
|
||||
GREASE_9 = 0x9A9A,
|
||||
GREASE_A = 0xAAAA,
|
||||
GREASE_B = 0xBABA,
|
||||
GREASE_C = 0xCACA,
|
||||
GREASE_D = 0xDADA,
|
||||
GREASE_E = 0xEAEA,
|
||||
};
|
||||
|
||||
// struct {
|
||||
// Credential credential;
|
||||
// SignaturePublicKey credential_key;
|
||||
// opaque signature<V>;
|
||||
// } CredentialBinding
|
||||
//
|
||||
// struct {
|
||||
// CredentialBinding bindings<V>;
|
||||
// } MultiCredential;
|
||||
struct CredentialBinding;
|
||||
struct CredentialBindingInput;
|
||||
|
||||
struct MultiCredential
|
||||
{
|
||||
MultiCredential() = default;
|
||||
MultiCredential(const std::vector<CredentialBindingInput>& binding_inputs,
|
||||
const SignaturePublicKey& signature_key);
|
||||
|
||||
std::vector<CredentialBinding> bindings;
|
||||
|
||||
bool valid_for(const SignaturePublicKey& pub) const;
|
||||
|
||||
TLS_SERIALIZABLE(bindings)
|
||||
};
|
||||
|
||||
// struct {
|
||||
// CredentialType credential_type;
|
||||
// select (credential_type) {
|
||||
// case basic:
|
||||
// BasicCredential;
|
||||
//
|
||||
// case x509:
|
||||
// opaque cert_data<1..2^24-1>;
|
||||
// };
|
||||
// } Credential;
|
||||
struct Credential
|
||||
{
|
||||
Credential() = default;
|
||||
|
||||
CredentialType type() const;
|
||||
|
||||
template<typename T>
|
||||
const T& get() const
|
||||
{
|
||||
return var::get<T>(_cred);
|
||||
}
|
||||
|
||||
static Credential basic(const bytes& identity);
|
||||
static Credential x509(const std::vector<bytes>& der_chain);
|
||||
static Credential userinfo_vc(const std::string& userinfo_vc_jwt);
|
||||
static Credential multi(
|
||||
const std::vector<CredentialBindingInput>& binding_inputs,
|
||||
const SignaturePublicKey& signature_key);
|
||||
|
||||
bool valid_for(const SignaturePublicKey& pub) const;
|
||||
|
||||
TLS_SERIALIZABLE(_cred)
|
||||
TLS_TRAITS(tls::variant<CredentialType>)
|
||||
|
||||
private:
|
||||
using SpecificCredential = var::variant<BasicCredential,
|
||||
X509Credential,
|
||||
UserInfoVCCredential,
|
||||
MultiCredential>;
|
||||
|
||||
Credential(SpecificCredential specific);
|
||||
SpecificCredential _cred;
|
||||
};
|
||||
|
||||
// XXX(RLB): This struct needs to appear below Credential so that all types are
|
||||
// concrete at the appropriate points.
|
||||
struct CredentialBindingInput
|
||||
{
|
||||
CipherSuite cipher_suite;
|
||||
Credential credential;
|
||||
const SignaturePrivateKey& credential_priv;
|
||||
};
|
||||
|
||||
struct CredentialBinding
|
||||
{
|
||||
CipherSuite cipher_suite;
|
||||
Credential credential;
|
||||
SignaturePublicKey credential_key;
|
||||
bytes signature;
|
||||
|
||||
CredentialBinding() = default;
|
||||
CredentialBinding(CipherSuite suite_in,
|
||||
Credential credential_in,
|
||||
const SignaturePrivateKey& credential_priv,
|
||||
const SignaturePublicKey& signature_key);
|
||||
|
||||
bool valid_for(const SignaturePublicKey& signature_key) const;
|
||||
|
||||
TLS_SERIALIZABLE(cipher_suite, credential, credential_key, signature)
|
||||
|
||||
private:
|
||||
bytes to_be_signed(const SignaturePublicKey& signature_key) const;
|
||||
};
|
||||
|
||||
} // namespace mlspp
|
||||
|
||||
namespace mlspp::tls {
|
||||
|
||||
TLS_VARIANT_MAP(mlspp::CredentialType,
|
||||
mlspp::BasicCredential,
|
||||
basic)
|
||||
TLS_VARIANT_MAP(mlspp::CredentialType,
|
||||
mlspp::X509Credential,
|
||||
x509)
|
||||
TLS_VARIANT_MAP(mlspp::CredentialType,
|
||||
mlspp::UserInfoVCCredential,
|
||||
userinfo_vc_draft_00)
|
||||
TLS_VARIANT_MAP(mlspp::CredentialType,
|
||||
mlspp::MultiCredential,
|
||||
multi_draft_00)
|
||||
|
||||
} // namespace mlspp::tls
|
||||
266
DPP/mlspp/include/mls/crypto.h
Executable file
266
DPP/mlspp/include/mls/crypto.h
Executable file
@@ -0,0 +1,266 @@
|
||||
#pragma once
|
||||
|
||||
#include <hpke/digest.h>
|
||||
#include <hpke/hpke.h>
|
||||
#include <hpke/random.h>
|
||||
#include <hpke/signature.h>
|
||||
#include <mls/common.h>
|
||||
#include <tls/tls_syntax.h>
|
||||
#include <vector>
|
||||
|
||||
namespace mlspp {
|
||||
|
||||
/// Signature Code points, borrowed from RFC 8446
|
||||
enum struct SignatureScheme : uint16_t
|
||||
{
|
||||
ecdsa_secp256r1_sha256 = 0x0403,
|
||||
ecdsa_secp384r1_sha384 = 0x0805,
|
||||
ecdsa_secp521r1_sha512 = 0x0603,
|
||||
ed25519 = 0x0807,
|
||||
ed448 = 0x0808,
|
||||
rsa_pkcs1_sha256 = 0x0401,
|
||||
};
|
||||
|
||||
SignatureScheme
|
||||
tls_signature_scheme(hpke::Signature::ID id);
|
||||
|
||||
/// Cipher suites
|
||||
|
||||
struct KeyAndNonce
|
||||
{
|
||||
bytes key;
|
||||
bytes nonce;
|
||||
};
|
||||
|
||||
// opaque HashReference<V>;
|
||||
// HashReference KeyPackageRef;
|
||||
// HashReference ProposalRef;
|
||||
using HashReference = bytes;
|
||||
using KeyPackageRef = HashReference;
|
||||
using ProposalRef = HashReference;
|
||||
|
||||
struct CipherSuite
|
||||
{
|
||||
enum struct ID : uint16_t
|
||||
{
|
||||
unknown = 0x0000,
|
||||
X25519_AES128GCM_SHA256_Ed25519 = 0x0001,
|
||||
P256_AES128GCM_SHA256_P256 = 0x0002,
|
||||
X25519_CHACHA20POLY1305_SHA256_Ed25519 = 0x0003,
|
||||
X448_AES256GCM_SHA512_Ed448 = 0x0004,
|
||||
P521_AES256GCM_SHA512_P521 = 0x0005,
|
||||
X448_CHACHA20POLY1305_SHA512_Ed448 = 0x0006,
|
||||
P384_AES256GCM_SHA384_P384 = 0x0007,
|
||||
|
||||
// GREASE values, included here mainly so that debugger output looks nice
|
||||
GREASE_0 = 0x0A0A,
|
||||
GREASE_1 = 0x1A1A,
|
||||
GREASE_2 = 0x2A2A,
|
||||
GREASE_3 = 0x3A3A,
|
||||
GREASE_4 = 0x4A4A,
|
||||
GREASE_5 = 0x5A5A,
|
||||
GREASE_6 = 0x6A6A,
|
||||
GREASE_7 = 0x7A7A,
|
||||
GREASE_8 = 0x8A8A,
|
||||
GREASE_9 = 0x9A9A,
|
||||
GREASE_A = 0xAAAA,
|
||||
GREASE_B = 0xBABA,
|
||||
GREASE_C = 0xCACA,
|
||||
GREASE_D = 0xDADA,
|
||||
GREASE_E = 0xEAEA,
|
||||
};
|
||||
|
||||
CipherSuite();
|
||||
CipherSuite(ID id_in);
|
||||
|
||||
ID cipher_suite() const { return id; }
|
||||
SignatureScheme signature_scheme() const;
|
||||
|
||||
size_t secret_size() const { return get().digest.hash_size; }
|
||||
size_t key_size() const { return get().hpke.aead.key_size; }
|
||||
size_t nonce_size() const { return get().hpke.aead.nonce_size; }
|
||||
|
||||
bytes zero() const { return bytes(secret_size(), 0); }
|
||||
const hpke::HPKE& hpke() const { return get().hpke; }
|
||||
const hpke::Digest& digest() const { return get().digest; }
|
||||
const hpke::Signature& sig() const { return get().sig; }
|
||||
|
||||
bytes expand_with_label(const bytes& secret,
|
||||
const std::string& label,
|
||||
const bytes& context,
|
||||
size_t length) const;
|
||||
bytes derive_secret(const bytes& secret, const std::string& label) const;
|
||||
bytes derive_tree_secret(const bytes& secret,
|
||||
const std::string& label,
|
||||
uint32_t generation,
|
||||
size_t length) const;
|
||||
|
||||
template<typename T>
|
||||
bytes ref(const T& value) const
|
||||
{
|
||||
return raw_ref(reference_label<T>(), tls::marshal(value));
|
||||
}
|
||||
|
||||
bytes raw_ref(const bytes& label, const bytes& value) const
|
||||
{
|
||||
// RefHash(label, value) = Hash(RefHashInput)
|
||||
//
|
||||
// struct {
|
||||
// opaque label<V>;
|
||||
// opaque value<V>;
|
||||
// } RefHashInput;
|
||||
auto w = tls::ostream();
|
||||
w << label << value;
|
||||
return digest().hash(w.bytes());
|
||||
}
|
||||
|
||||
TLS_SERIALIZABLE(id)
|
||||
|
||||
private:
|
||||
ID id;
|
||||
|
||||
struct Ciphers
|
||||
{
|
||||
hpke::HPKE hpke;
|
||||
const hpke::Digest& digest;
|
||||
const hpke::Signature& sig;
|
||||
};
|
||||
|
||||
const Ciphers& get() const;
|
||||
|
||||
template<typename T>
|
||||
static const bytes& reference_label();
|
||||
};
|
||||
|
||||
#if WITH_BORINGSSL
|
||||
extern const std::array<CipherSuite::ID, 5> all_supported_suites;
|
||||
#else
|
||||
extern const std::array<CipherSuite::ID, 7> all_supported_suites;
|
||||
#endif
|
||||
|
||||
// Utilities
|
||||
using mlspp::hpke::random_bytes;
|
||||
|
||||
// HPKE Keys
|
||||
namespace encrypt_label {
|
||||
extern const std::string update_path_node;
|
||||
extern const std::string welcome;
|
||||
} // namespace encrypt_label
|
||||
|
||||
struct HPKECiphertext
|
||||
{
|
||||
bytes kem_output;
|
||||
bytes ciphertext;
|
||||
|
||||
TLS_SERIALIZABLE(kem_output, ciphertext)
|
||||
};
|
||||
|
||||
struct HPKEPublicKey
|
||||
{
|
||||
bytes data;
|
||||
|
||||
HPKECiphertext encrypt(CipherSuite suite,
|
||||
const std::string& label,
|
||||
const bytes& context,
|
||||
const bytes& pt) const;
|
||||
|
||||
std::tuple<bytes, bytes> do_export(CipherSuite suite,
|
||||
const bytes& info,
|
||||
const std::string& label,
|
||||
size_t size) const;
|
||||
|
||||
TLS_SERIALIZABLE(data)
|
||||
};
|
||||
|
||||
struct HPKEPrivateKey
|
||||
{
|
||||
static HPKEPrivateKey generate(CipherSuite suite);
|
||||
static HPKEPrivateKey parse(CipherSuite suite, const bytes& data);
|
||||
static HPKEPrivateKey derive(CipherSuite suite, const bytes& secret);
|
||||
|
||||
HPKEPrivateKey() = default;
|
||||
|
||||
bytes data;
|
||||
HPKEPublicKey public_key;
|
||||
|
||||
bytes decrypt(CipherSuite suite,
|
||||
const std::string& label,
|
||||
const bytes& context,
|
||||
const HPKECiphertext& ct) const;
|
||||
|
||||
bytes do_export(CipherSuite suite,
|
||||
const bytes& info,
|
||||
const bytes& kem_output,
|
||||
const std::string& label,
|
||||
size_t size) const;
|
||||
|
||||
void set_public_key(CipherSuite suite);
|
||||
|
||||
TLS_SERIALIZABLE(data)
|
||||
|
||||
private:
|
||||
HPKEPrivateKey(bytes priv_data, bytes pub_data);
|
||||
};
|
||||
|
||||
// Signature Keys
|
||||
namespace sign_label {
|
||||
extern const std::string mls_content;
|
||||
extern const std::string leaf_node;
|
||||
extern const std::string key_package;
|
||||
extern const std::string group_info;
|
||||
extern const std::string multi_credential;
|
||||
} // namespace sign_label
|
||||
|
||||
struct SignaturePublicKey
|
||||
{
|
||||
static SignaturePublicKey from_jwk(CipherSuite suite,
|
||||
const std::string& json_str);
|
||||
|
||||
bytes data;
|
||||
|
||||
bool verify(const CipherSuite& suite,
|
||||
const std::string& label,
|
||||
const bytes& message,
|
||||
const bytes& signature) const;
|
||||
|
||||
std::string to_jwk(CipherSuite suite) const;
|
||||
|
||||
TLS_SERIALIZABLE(data)
|
||||
};
|
||||
|
||||
struct PublicJWK
|
||||
{
|
||||
SignatureScheme signature_scheme;
|
||||
std::optional<std::string> key_id;
|
||||
SignaturePublicKey public_key;
|
||||
|
||||
static PublicJWK parse(const std::string& jwk_json);
|
||||
};
|
||||
|
||||
struct SignaturePrivateKey
|
||||
{
|
||||
static SignaturePrivateKey generate(CipherSuite suite);
|
||||
static SignaturePrivateKey parse(CipherSuite suite, const bytes& data);
|
||||
static SignaturePrivateKey derive(CipherSuite suite, const bytes& secret);
|
||||
static SignaturePrivateKey from_jwk(CipherSuite suite,
|
||||
const std::string& json_str);
|
||||
|
||||
SignaturePrivateKey() = default;
|
||||
|
||||
bytes data;
|
||||
SignaturePublicKey public_key;
|
||||
|
||||
bytes sign(const CipherSuite& suite,
|
||||
const std::string& label,
|
||||
const bytes& message) const;
|
||||
|
||||
void set_public_key(CipherSuite suite);
|
||||
std::string to_jwk(CipherSuite suite) const;
|
||||
|
||||
TLS_SERIALIZABLE(data)
|
||||
|
||||
private:
|
||||
SignaturePrivateKey(bytes priv_data, bytes pub_data);
|
||||
};
|
||||
|
||||
} // namespace mlspp
|
||||
205
DPP/mlspp/include/mls/key_schedule.h
Executable file
205
DPP/mlspp/include/mls/key_schedule.h
Executable file
@@ -0,0 +1,205 @@
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <mls/common.h>
|
||||
#include <mls/crypto.h>
|
||||
#include <mls/messages.h>
|
||||
#include <mls/tree_math.h>
|
||||
|
||||
namespace mlspp {
|
||||
|
||||
struct HashRatchet
|
||||
{
|
||||
CipherSuite suite;
|
||||
bytes next_secret;
|
||||
uint32_t next_generation;
|
||||
std::map<uint32_t, KeyAndNonce> cache;
|
||||
|
||||
size_t key_size;
|
||||
size_t nonce_size;
|
||||
size_t secret_size;
|
||||
|
||||
// These defaults are necessary for use with containers
|
||||
HashRatchet() = default;
|
||||
HashRatchet(const HashRatchet& other) = default;
|
||||
HashRatchet(HashRatchet&& other) = default;
|
||||
HashRatchet& operator=(const HashRatchet& other) = default;
|
||||
HashRatchet& operator=(HashRatchet&& other) = default;
|
||||
|
||||
HashRatchet(CipherSuite suite_in, bytes base_secret_in);
|
||||
|
||||
std::tuple<uint32_t, KeyAndNonce> next();
|
||||
KeyAndNonce get(uint32_t generation);
|
||||
void erase(uint32_t generation);
|
||||
};
|
||||
|
||||
struct SecretTree
|
||||
{
|
||||
SecretTree() = default;
|
||||
SecretTree(CipherSuite suite_in,
|
||||
LeafCount group_size_in,
|
||||
bytes encryption_secret_in);
|
||||
|
||||
bool has_leaf(LeafIndex sender) { return sender < group_size; }
|
||||
|
||||
bytes get(LeafIndex sender);
|
||||
|
||||
private:
|
||||
CipherSuite suite;
|
||||
LeafCount group_size;
|
||||
NodeIndex root;
|
||||
std::map<NodeIndex, bytes> secrets;
|
||||
size_t secret_size;
|
||||
};
|
||||
|
||||
using ReuseGuard = std::array<uint8_t, 4>;
|
||||
|
||||
struct GroupKeySource
|
||||
{
|
||||
enum struct RatchetType
|
||||
{
|
||||
handshake,
|
||||
application,
|
||||
};
|
||||
|
||||
GroupKeySource() = default;
|
||||
GroupKeySource(CipherSuite suite_in,
|
||||
LeafCount group_size,
|
||||
bytes encryption_secret);
|
||||
|
||||
bool has_leaf(LeafIndex sender) { return secret_tree.has_leaf(sender); }
|
||||
|
||||
std::tuple<uint32_t, ReuseGuard, KeyAndNonce> next(ContentType content_type,
|
||||
LeafIndex sender);
|
||||
KeyAndNonce get(ContentType content_type,
|
||||
LeafIndex sender,
|
||||
uint32_t generation,
|
||||
ReuseGuard reuse_guard);
|
||||
void erase(ContentType type, LeafIndex sender, uint32_t generation);
|
||||
|
||||
private:
|
||||
CipherSuite suite;
|
||||
SecretTree secret_tree;
|
||||
|
||||
using Key = std::tuple<RatchetType, LeafIndex>;
|
||||
std::map<Key, HashRatchet> chains;
|
||||
|
||||
HashRatchet& chain(RatchetType type, LeafIndex sender);
|
||||
HashRatchet& chain(ContentType type, LeafIndex sender);
|
||||
|
||||
static const std::array<RatchetType, 2> all_ratchet_types;
|
||||
};
|
||||
|
||||
struct KeyScheduleEpoch
|
||||
{
|
||||
private:
|
||||
CipherSuite suite;
|
||||
|
||||
public:
|
||||
bytes joiner_secret;
|
||||
bytes epoch_secret;
|
||||
|
||||
bytes sender_data_secret;
|
||||
bytes encryption_secret;
|
||||
bytes exporter_secret;
|
||||
bytes epoch_authenticator;
|
||||
bytes external_secret;
|
||||
bytes confirmation_key;
|
||||
bytes membership_key;
|
||||
bytes resumption_psk;
|
||||
bytes init_secret;
|
||||
|
||||
HPKEPrivateKey external_priv;
|
||||
|
||||
KeyScheduleEpoch() = default;
|
||||
|
||||
// Full initializer, used by invited joiner
|
||||
static KeyScheduleEpoch joiner(CipherSuite suite_in,
|
||||
const bytes& joiner_secret,
|
||||
const std::vector<PSKWithSecret>& psks,
|
||||
const bytes& context);
|
||||
|
||||
// Ciphersuite-only initializer, used by external joiner
|
||||
KeyScheduleEpoch(CipherSuite suite_in);
|
||||
|
||||
// Initial epoch
|
||||
KeyScheduleEpoch(CipherSuite suite_in,
|
||||
const bytes& init_secret,
|
||||
const bytes& context);
|
||||
|
||||
static std::tuple<bytes, bytes> external_init(
|
||||
CipherSuite suite,
|
||||
const HPKEPublicKey& external_pub);
|
||||
bytes receive_external_init(const bytes& kem_output) const;
|
||||
|
||||
KeyScheduleEpoch next(const bytes& commit_secret,
|
||||
const std::vector<PSKWithSecret>& psks,
|
||||
const std::optional<bytes>& force_init_secret,
|
||||
const bytes& context) const;
|
||||
|
||||
GroupKeySource encryption_keys(LeafCount size) const;
|
||||
bytes confirmation_tag(const bytes& confirmed_transcript_hash) const;
|
||||
bytes do_export(const std::string& label,
|
||||
const bytes& context,
|
||||
size_t size) const;
|
||||
PSKWithSecret resumption_psk_w_secret(ResumptionPSKUsage usage,
|
||||
const bytes& group_id,
|
||||
epoch_t epoch);
|
||||
|
||||
static bytes make_psk_secret(CipherSuite suite,
|
||||
const std::vector<PSKWithSecret>& psks);
|
||||
static bytes welcome_secret(CipherSuite suite,
|
||||
const bytes& joiner_secret,
|
||||
const std::vector<PSKWithSecret>& psks);
|
||||
static KeyAndNonce sender_data_keys(CipherSuite suite,
|
||||
const bytes& sender_data_secret,
|
||||
const bytes& ciphertext);
|
||||
|
||||
// TODO(RLB) make these methods private, but accessible to test vectors
|
||||
KeyScheduleEpoch(CipherSuite suite_in,
|
||||
const bytes& init_secret,
|
||||
const bytes& commit_secret,
|
||||
const bytes& psk_secret,
|
||||
const bytes& context);
|
||||
KeyScheduleEpoch next_raw(const bytes& commit_secret,
|
||||
const bytes& psk_secret,
|
||||
const std::optional<bytes>& force_init_secret,
|
||||
const bytes& context) const;
|
||||
static bytes welcome_secret_raw(CipherSuite suite,
|
||||
const bytes& joiner_secret,
|
||||
const bytes& psk_secret);
|
||||
|
||||
private:
|
||||
KeyScheduleEpoch(CipherSuite suite_in,
|
||||
const bytes& joiner_secret,
|
||||
const bytes& psk_secret,
|
||||
const bytes& context);
|
||||
};
|
||||
|
||||
bool
|
||||
operator==(const KeyScheduleEpoch& lhs, const KeyScheduleEpoch& rhs);
|
||||
|
||||
struct TranscriptHash
|
||||
{
|
||||
CipherSuite suite;
|
||||
bytes confirmed;
|
||||
bytes interim;
|
||||
|
||||
// For a new group
|
||||
TranscriptHash(CipherSuite suite_in);
|
||||
|
||||
// For joining a group
|
||||
TranscriptHash(CipherSuite suite_in,
|
||||
bytes confirmed_in,
|
||||
const bytes& confirmation_tag);
|
||||
|
||||
void update(const AuthenticatedContent& content_auth);
|
||||
void update_confirmed(const AuthenticatedContent& content_auth);
|
||||
void update_interim(const bytes& confirmation_tag);
|
||||
void update_interim(const AuthenticatedContent& content_auth);
|
||||
};
|
||||
|
||||
bool
|
||||
operator==(const TranscriptHash& lhs, const TranscriptHash& rhs);
|
||||
|
||||
} // namespace mlspp
|
||||
752
DPP/mlspp/include/mls/messages.h
Executable file
752
DPP/mlspp/include/mls/messages.h
Executable file
@@ -0,0 +1,752 @@
|
||||
#pragma once
|
||||
|
||||
#include "mls/common.h"
|
||||
#include "mls/core_types.h"
|
||||
#include "mls/credential.h"
|
||||
#include "mls/crypto.h"
|
||||
#include "mls/treekem.h"
|
||||
#include <optional>
|
||||
#include <tls/tls_syntax.h>
|
||||
|
||||
namespace mlspp {
|
||||
|
||||
struct ExternalPubExtension
|
||||
{
|
||||
HPKEPublicKey external_pub;
|
||||
|
||||
static const uint16_t type;
|
||||
TLS_SERIALIZABLE(external_pub)
|
||||
};
|
||||
|
||||
struct RatchetTreeExtension
|
||||
{
|
||||
TreeKEMPublicKey tree;
|
||||
|
||||
static const uint16_t type;
|
||||
TLS_SERIALIZABLE(tree)
|
||||
};
|
||||
|
||||
struct ExternalSender
|
||||
{
|
||||
SignaturePublicKey signature_key;
|
||||
Credential credential;
|
||||
|
||||
TLS_SERIALIZABLE(signature_key, credential);
|
||||
};
|
||||
|
||||
struct ExternalSendersExtension
|
||||
{
|
||||
std::vector<ExternalSender> senders;
|
||||
|
||||
static const uint16_t type;
|
||||
TLS_SERIALIZABLE(senders);
|
||||
};
|
||||
|
||||
struct SFrameParameters
|
||||
{
|
||||
uint16_t cipher_suite;
|
||||
uint8_t epoch_bits;
|
||||
|
||||
static const uint16_t type;
|
||||
TLS_SERIALIZABLE(cipher_suite, epoch_bits)
|
||||
};
|
||||
|
||||
struct SFrameCapabilities
|
||||
{
|
||||
std::vector<uint16_t> cipher_suites;
|
||||
|
||||
bool compatible(const SFrameParameters& params) const;
|
||||
|
||||
static const uint16_t type;
|
||||
TLS_SERIALIZABLE(cipher_suites)
|
||||
};
|
||||
|
||||
///
|
||||
/// PSKs
|
||||
///
|
||||
enum struct PSKType : uint8_t
|
||||
{
|
||||
reserved = 0,
|
||||
external = 1,
|
||||
resumption = 2,
|
||||
};
|
||||
|
||||
struct ExternalPSK
|
||||
{
|
||||
bytes psk_id;
|
||||
TLS_SERIALIZABLE(psk_id)
|
||||
};
|
||||
|
||||
enum struct ResumptionPSKUsage : uint8_t
|
||||
{
|
||||
reserved = 0,
|
||||
application = 1,
|
||||
reinit = 2,
|
||||
branch = 3,
|
||||
};
|
||||
|
||||
struct ResumptionPSK
|
||||
{
|
||||
ResumptionPSKUsage usage;
|
||||
bytes psk_group_id;
|
||||
epoch_t psk_epoch;
|
||||
TLS_SERIALIZABLE(usage, psk_group_id, psk_epoch)
|
||||
};
|
||||
|
||||
struct PreSharedKeyID
|
||||
{
|
||||
var::variant<ExternalPSK, ResumptionPSK> content;
|
||||
bytes psk_nonce;
|
||||
TLS_SERIALIZABLE(content, psk_nonce)
|
||||
TLS_TRAITS(tls::variant<PSKType>, tls::pass)
|
||||
};
|
||||
|
||||
struct PreSharedKeys
|
||||
{
|
||||
std::vector<PreSharedKeyID> psks;
|
||||
TLS_SERIALIZABLE(psks)
|
||||
};
|
||||
|
||||
struct PSKWithSecret
|
||||
{
|
||||
PreSharedKeyID id;
|
||||
bytes secret;
|
||||
};
|
||||
|
||||
// struct {
|
||||
// ProtocolVersion version = mls10;
|
||||
// CipherSuite cipher_suite;
|
||||
// opaque group_id<V>;
|
||||
// uint64 epoch;
|
||||
// opaque tree_hash<V>;
|
||||
// opaque confirmed_transcript_hash<V>;
|
||||
// Extension extensions<V>;
|
||||
// } GroupContext;
|
||||
struct GroupContext
|
||||
{
|
||||
ProtocolVersion version{ ProtocolVersion::mls10 };
|
||||
CipherSuite cipher_suite;
|
||||
bytes group_id;
|
||||
epoch_t epoch;
|
||||
bytes tree_hash;
|
||||
bytes confirmed_transcript_hash;
|
||||
ExtensionList extensions;
|
||||
|
||||
GroupContext() = default;
|
||||
GroupContext(CipherSuite cipher_suite_in,
|
||||
bytes group_id_in,
|
||||
epoch_t epoch_in,
|
||||
bytes tree_hash_in,
|
||||
bytes confirmed_transcript_hash_in,
|
||||
ExtensionList extensions_in);
|
||||
|
||||
TLS_SERIALIZABLE(version,
|
||||
cipher_suite,
|
||||
group_id,
|
||||
epoch,
|
||||
tree_hash,
|
||||
confirmed_transcript_hash,
|
||||
extensions)
|
||||
};
|
||||
|
||||
// struct {
|
||||
// GroupContext group_context;
|
||||
// Extension extensions<V>;
|
||||
// MAC confirmation_tag;
|
||||
// uint32 signer;
|
||||
// // SignWithLabel(., "GroupInfoTBS", GroupInfoTBS)
|
||||
// opaque signature<V>;
|
||||
// } GroupInfo;
|
||||
struct GroupInfo
|
||||
{
|
||||
GroupContext group_context;
|
||||
ExtensionList extensions;
|
||||
bytes confirmation_tag;
|
||||
LeafIndex signer;
|
||||
bytes signature;
|
||||
|
||||
GroupInfo() = default;
|
||||
GroupInfo(GroupContext group_context_in,
|
||||
ExtensionList extensions_in,
|
||||
bytes confirmation_tag_in);
|
||||
|
||||
bytes to_be_signed() const;
|
||||
void sign(const TreeKEMPublicKey& tree,
|
||||
LeafIndex signer_index,
|
||||
const SignaturePrivateKey& priv);
|
||||
bool verify(const TreeKEMPublicKey& tree) const;
|
||||
|
||||
// These methods exist only to simplify unit testing
|
||||
void sign(LeafIndex signer_index, const SignaturePrivateKey& priv);
|
||||
bool verify(const SignaturePublicKey& pub) const;
|
||||
|
||||
TLS_SERIALIZABLE(group_context,
|
||||
extensions,
|
||||
confirmation_tag,
|
||||
signer,
|
||||
signature)
|
||||
};
|
||||
|
||||
// struct {
|
||||
// opaque joiner_secret<1..255>;
|
||||
// optional<PathSecret> path_secret;
|
||||
// PreSharedKeys psks;
|
||||
// } GroupSecrets;
|
||||
struct GroupSecrets
|
||||
{
|
||||
struct PathSecret
|
||||
{
|
||||
bytes secret;
|
||||
|
||||
TLS_SERIALIZABLE(secret)
|
||||
};
|
||||
|
||||
bytes joiner_secret;
|
||||
std::optional<PathSecret> path_secret;
|
||||
PreSharedKeys psks;
|
||||
|
||||
TLS_SERIALIZABLE(joiner_secret, path_secret, psks)
|
||||
};
|
||||
|
||||
// struct {
|
||||
// opaque key_package_hash<1..255>;
|
||||
// HPKECiphertext encrypted_group_secrets;
|
||||
// } EncryptedGroupSecrets;
|
||||
struct EncryptedGroupSecrets
|
||||
{
|
||||
KeyPackageRef new_member;
|
||||
HPKECiphertext encrypted_group_secrets;
|
||||
|
||||
TLS_SERIALIZABLE(new_member, encrypted_group_secrets)
|
||||
};
|
||||
|
||||
// struct {
|
||||
// ProtocolVersion version = mls10;
|
||||
// CipherSuite cipher_suite;
|
||||
// EncryptedGroupSecrets group_secretss<1..2^32-1>;
|
||||
// opaque encrypted_group_info<1..2^32-1>;
|
||||
// } Welcome;
|
||||
struct Welcome
|
||||
{
|
||||
CipherSuite cipher_suite;
|
||||
std::vector<EncryptedGroupSecrets> secrets;
|
||||
bytes encrypted_group_info;
|
||||
|
||||
Welcome();
|
||||
Welcome(CipherSuite suite,
|
||||
const bytes& joiner_secret,
|
||||
const std::vector<PSKWithSecret>& psks,
|
||||
const GroupInfo& group_info);
|
||||
|
||||
void encrypt(const KeyPackage& kp, const std::optional<bytes>& path_secret);
|
||||
std::optional<int> find(const KeyPackage& kp) const;
|
||||
GroupSecrets decrypt_secrets(int kp_index,
|
||||
const HPKEPrivateKey& init_priv) const;
|
||||
GroupInfo decrypt(const bytes& joiner_secret,
|
||||
const std::vector<PSKWithSecret>& psks) const;
|
||||
|
||||
TLS_SERIALIZABLE(cipher_suite, secrets, encrypted_group_info)
|
||||
|
||||
private:
|
||||
bytes _joiner_secret;
|
||||
PreSharedKeys _psks;
|
||||
static KeyAndNonce group_info_key_nonce(
|
||||
CipherSuite suite,
|
||||
const bytes& joiner_secret,
|
||||
const std::vector<PSKWithSecret>& psks);
|
||||
};
|
||||
|
||||
///
|
||||
/// Proposals & Commit
|
||||
///
|
||||
|
||||
// Add
|
||||
struct Add
|
||||
{
|
||||
KeyPackage key_package;
|
||||
TLS_SERIALIZABLE(key_package)
|
||||
};
|
||||
|
||||
// Update
|
||||
struct Update
|
||||
{
|
||||
LeafNode leaf_node;
|
||||
TLS_SERIALIZABLE(leaf_node)
|
||||
};
|
||||
|
||||
// Remove
|
||||
struct Remove
|
||||
{
|
||||
LeafIndex removed;
|
||||
TLS_SERIALIZABLE(removed)
|
||||
};
|
||||
|
||||
// PreSharedKey
|
||||
struct PreSharedKey
|
||||
{
|
||||
PreSharedKeyID psk;
|
||||
TLS_SERIALIZABLE(psk)
|
||||
};
|
||||
|
||||
// ReInit
|
||||
struct ReInit
|
||||
{
|
||||
bytes group_id;
|
||||
ProtocolVersion version;
|
||||
CipherSuite cipher_suite;
|
||||
ExtensionList extensions;
|
||||
|
||||
TLS_SERIALIZABLE(group_id, version, cipher_suite, extensions)
|
||||
};
|
||||
|
||||
// ExternalInit
|
||||
struct ExternalInit
|
||||
{
|
||||
bytes kem_output;
|
||||
TLS_SERIALIZABLE(kem_output)
|
||||
};
|
||||
|
||||
// GroupContextExtensions
|
||||
struct GroupContextExtensions
|
||||
{
|
||||
ExtensionList group_context_extensions;
|
||||
TLS_SERIALIZABLE(group_context_extensions)
|
||||
};
|
||||
|
||||
struct ProposalType;
|
||||
|
||||
struct Proposal
|
||||
{
|
||||
using Type = uint16_t;
|
||||
|
||||
var::variant<Add,
|
||||
Update,
|
||||
Remove,
|
||||
PreSharedKey,
|
||||
ReInit,
|
||||
ExternalInit,
|
||||
GroupContextExtensions>
|
||||
content;
|
||||
|
||||
Type proposal_type() const;
|
||||
|
||||
TLS_SERIALIZABLE(content)
|
||||
TLS_TRAITS(tls::variant<ProposalType>)
|
||||
};
|
||||
|
||||
struct ProposalType
|
||||
{
|
||||
static constexpr Proposal::Type invalid = 0;
|
||||
static constexpr Proposal::Type add = 1;
|
||||
static constexpr Proposal::Type update = 2;
|
||||
static constexpr Proposal::Type remove = 3;
|
||||
static constexpr Proposal::Type psk = 4;
|
||||
static constexpr Proposal::Type reinit = 5;
|
||||
static constexpr Proposal::Type external_init = 6;
|
||||
static constexpr Proposal::Type group_context_extensions = 7;
|
||||
|
||||
constexpr ProposalType()
|
||||
: val(invalid)
|
||||
{
|
||||
}
|
||||
|
||||
constexpr ProposalType(Proposal::Type pt)
|
||||
: val(pt)
|
||||
{
|
||||
}
|
||||
|
||||
Proposal::Type val;
|
||||
TLS_SERIALIZABLE(val)
|
||||
};
|
||||
|
||||
enum struct ProposalOrRefType : uint8_t
|
||||
{
|
||||
reserved = 0,
|
||||
value = 1,
|
||||
reference = 2,
|
||||
};
|
||||
|
||||
struct ProposalOrRef
|
||||
{
|
||||
var::variant<Proposal, ProposalRef> content;
|
||||
|
||||
TLS_SERIALIZABLE(content)
|
||||
TLS_TRAITS(tls::variant<ProposalOrRefType>)
|
||||
};
|
||||
|
||||
// struct {
|
||||
// ProposalOrRef proposals<0..2^32-1>;
|
||||
// optional<UpdatePath> path;
|
||||
// } Commit;
|
||||
struct Commit
|
||||
{
|
||||
std::vector<ProposalOrRef> proposals;
|
||||
std::optional<UpdatePath> path;
|
||||
|
||||
// Validate that the commit is acceptable as an external commit, and if so,
|
||||
// produce the public key from the ExternalInit proposal
|
||||
std::optional<bytes> valid_external() const;
|
||||
|
||||
TLS_SERIALIZABLE(proposals, path)
|
||||
};
|
||||
|
||||
// struct {
|
||||
// opaque group_id<0..255>;
|
||||
// uint32 epoch;
|
||||
// uint32 sender;
|
||||
// ContentType content_type;
|
||||
//
|
||||
// select (PublicMessage.content_type) {
|
||||
// case handshake:
|
||||
// GroupOperation operation;
|
||||
// opaque confirmation<0..255>;
|
||||
//
|
||||
// case application:
|
||||
// opaque application_data<0..2^32-1>;
|
||||
// }
|
||||
//
|
||||
// opaque signature<0..2^16-1>;
|
||||
// } PublicMessage;
|
||||
struct ApplicationData
|
||||
{
|
||||
bytes data;
|
||||
TLS_SERIALIZABLE(data)
|
||||
};
|
||||
|
||||
struct GroupContext;
|
||||
|
||||
enum struct WireFormat : uint16_t
|
||||
{
|
||||
reserved = 0,
|
||||
mls_public_message = 1,
|
||||
mls_private_message = 2,
|
||||
mls_welcome = 3,
|
||||
mls_group_info = 4,
|
||||
mls_key_package = 5,
|
||||
};
|
||||
|
||||
enum struct ContentType : uint8_t
|
||||
{
|
||||
invalid = 0,
|
||||
application = 1,
|
||||
proposal = 2,
|
||||
commit = 3,
|
||||
};
|
||||
|
||||
enum struct SenderType : uint8_t
|
||||
{
|
||||
invalid = 0,
|
||||
member = 1,
|
||||
external = 2,
|
||||
new_member_proposal = 3,
|
||||
new_member_commit = 4,
|
||||
};
|
||||
|
||||
struct MemberSender
|
||||
{
|
||||
LeafIndex sender;
|
||||
TLS_SERIALIZABLE(sender);
|
||||
};
|
||||
|
||||
struct ExternalSenderIndex
|
||||
{
|
||||
uint32_t sender_index;
|
||||
TLS_SERIALIZABLE(sender_index)
|
||||
};
|
||||
|
||||
struct NewMemberProposalSender
|
||||
{
|
||||
TLS_SERIALIZABLE()
|
||||
};
|
||||
|
||||
struct NewMemberCommitSender
|
||||
{
|
||||
TLS_SERIALIZABLE()
|
||||
};
|
||||
|
||||
struct Sender
|
||||
{
|
||||
var::variant<MemberSender,
|
||||
ExternalSenderIndex,
|
||||
NewMemberProposalSender,
|
||||
NewMemberCommitSender>
|
||||
sender;
|
||||
|
||||
SenderType sender_type() const;
|
||||
|
||||
TLS_SERIALIZABLE(sender)
|
||||
TLS_TRAITS(tls::variant<SenderType>)
|
||||
};
|
||||
|
||||
///
|
||||
/// MLSMessage and friends
|
||||
///
|
||||
struct GroupKeySource;
|
||||
|
||||
struct GroupContent
|
||||
{
|
||||
using RawContent = var::variant<ApplicationData, Proposal, Commit>;
|
||||
|
||||
bytes group_id;
|
||||
epoch_t epoch;
|
||||
Sender sender;
|
||||
bytes authenticated_data;
|
||||
RawContent content;
|
||||
|
||||
GroupContent() = default;
|
||||
GroupContent(bytes group_id_in,
|
||||
epoch_t epoch_in,
|
||||
Sender sender_in,
|
||||
bytes authenticated_data_in,
|
||||
RawContent content_in);
|
||||
GroupContent(bytes group_id_in,
|
||||
epoch_t epoch_in,
|
||||
Sender sender_in,
|
||||
bytes authenticated_data_in,
|
||||
ContentType content_type);
|
||||
|
||||
ContentType content_type() const;
|
||||
|
||||
TLS_SERIALIZABLE(group_id, epoch, sender, authenticated_data, content)
|
||||
TLS_TRAITS(tls::pass,
|
||||
tls::pass,
|
||||
tls::pass,
|
||||
tls::pass,
|
||||
tls::variant<ContentType>)
|
||||
};
|
||||
|
||||
struct GroupContentAuthData
|
||||
{
|
||||
ContentType content_type = ContentType::invalid;
|
||||
bytes signature;
|
||||
std::optional<bytes> confirmation_tag;
|
||||
|
||||
friend tls::ostream& operator<<(tls::ostream& str,
|
||||
const GroupContentAuthData& obj);
|
||||
friend tls::istream& operator>>(tls::istream& str, GroupContentAuthData& obj);
|
||||
friend bool operator==(const GroupContentAuthData& lhs,
|
||||
const GroupContentAuthData& rhs);
|
||||
};
|
||||
|
||||
struct AuthenticatedContent
|
||||
{
|
||||
WireFormat wire_format;
|
||||
GroupContent content;
|
||||
GroupContentAuthData auth;
|
||||
|
||||
AuthenticatedContent() = default;
|
||||
|
||||
static AuthenticatedContent sign(WireFormat wire_format,
|
||||
GroupContent content,
|
||||
CipherSuite suite,
|
||||
const SignaturePrivateKey& sig_priv,
|
||||
const std::optional<GroupContext>& context);
|
||||
bool verify(CipherSuite suite,
|
||||
const SignaturePublicKey& sig_pub,
|
||||
const std::optional<GroupContext>& context) const;
|
||||
|
||||
bytes confirmed_transcript_hash_input() const;
|
||||
bytes interim_transcript_hash_input() const;
|
||||
|
||||
void set_confirmation_tag(const bytes& confirmation_tag);
|
||||
bool check_confirmation_tag(const bytes& confirmation_tag) const;
|
||||
|
||||
friend tls::ostream& operator<<(tls::ostream& str,
|
||||
const AuthenticatedContent& obj);
|
||||
friend tls::istream& operator>>(tls::istream& str, AuthenticatedContent& obj);
|
||||
friend bool operator==(const AuthenticatedContent& lhs,
|
||||
const AuthenticatedContent& rhs);
|
||||
|
||||
private:
|
||||
AuthenticatedContent(WireFormat wire_format_in, GroupContent content_in);
|
||||
AuthenticatedContent(WireFormat wire_format_in,
|
||||
GroupContent content_in,
|
||||
GroupContentAuthData auth_in);
|
||||
|
||||
bytes to_be_signed(const std::optional<GroupContext>& context) const;
|
||||
|
||||
friend struct PublicMessage;
|
||||
friend struct PrivateMessage;
|
||||
};
|
||||
|
||||
struct ValidatedContent
|
||||
{
|
||||
const AuthenticatedContent& authenticated_content() const;
|
||||
|
||||
friend bool operator==(const ValidatedContent& lhs,
|
||||
const ValidatedContent& rhs);
|
||||
|
||||
private:
|
||||
AuthenticatedContent content_auth;
|
||||
|
||||
ValidatedContent(AuthenticatedContent content_auth_in);
|
||||
|
||||
friend struct PublicMessage;
|
||||
friend struct PrivateMessage;
|
||||
friend class State;
|
||||
};
|
||||
|
||||
struct PublicMessage
|
||||
{
|
||||
PublicMessage() = default;
|
||||
|
||||
bytes get_group_id() const { return content.group_id; }
|
||||
epoch_t get_epoch() const { return content.epoch; }
|
||||
|
||||
static PublicMessage protect(AuthenticatedContent content_auth,
|
||||
CipherSuite suite,
|
||||
const std::optional<bytes>& membership_key,
|
||||
const std::optional<GroupContext>& context);
|
||||
std::optional<ValidatedContent> unprotect(
|
||||
CipherSuite suite,
|
||||
const std::optional<bytes>& membership_key,
|
||||
const std::optional<GroupContext>& context) const;
|
||||
|
||||
bool contains(const AuthenticatedContent& content_auth) const;
|
||||
|
||||
// TODO(RLB) Make this private and expose only to tests
|
||||
AuthenticatedContent authenticated_content() const;
|
||||
|
||||
friend tls::ostream& operator<<(tls::ostream& str, const PublicMessage& obj);
|
||||
friend tls::istream& operator>>(tls::istream& str, PublicMessage& obj);
|
||||
friend bool operator==(const PublicMessage& lhs, const PublicMessage& rhs);
|
||||
friend bool operator!=(const PublicMessage& lhs, const PublicMessage& rhs);
|
||||
|
||||
private:
|
||||
GroupContent content;
|
||||
GroupContentAuthData auth;
|
||||
std::optional<bytes> membership_tag;
|
||||
|
||||
PublicMessage(AuthenticatedContent content_auth);
|
||||
|
||||
bytes membership_mac(CipherSuite suite,
|
||||
const bytes& membership_key,
|
||||
const std::optional<GroupContext>& context) const;
|
||||
};
|
||||
|
||||
struct PrivateMessage
|
||||
{
|
||||
PrivateMessage() = default;
|
||||
|
||||
bytes get_group_id() const { return group_id; }
|
||||
epoch_t get_epoch() const { return epoch; }
|
||||
|
||||
static PrivateMessage protect(AuthenticatedContent content_auth,
|
||||
CipherSuite suite,
|
||||
GroupKeySource& keys,
|
||||
const bytes& sender_data_secret,
|
||||
size_t padding_size);
|
||||
std::optional<ValidatedContent> unprotect(
|
||||
CipherSuite suite,
|
||||
GroupKeySource& keys,
|
||||
const bytes& sender_data_secret) const;
|
||||
|
||||
TLS_SERIALIZABLE(group_id,
|
||||
epoch,
|
||||
content_type,
|
||||
authenticated_data,
|
||||
encrypted_sender_data,
|
||||
ciphertext)
|
||||
|
||||
private:
|
||||
bytes group_id;
|
||||
epoch_t epoch;
|
||||
ContentType content_type;
|
||||
bytes authenticated_data;
|
||||
bytes encrypted_sender_data;
|
||||
bytes ciphertext;
|
||||
|
||||
PrivateMessage(GroupContent content,
|
||||
bytes encrypted_sender_data_in,
|
||||
bytes ciphertext_in);
|
||||
};
|
||||
|
||||
struct MLSMessage
|
||||
{
|
||||
ProtocolVersion version = ProtocolVersion::mls10;
|
||||
var::variant<PublicMessage, PrivateMessage, Welcome, GroupInfo, KeyPackage>
|
||||
message;
|
||||
|
||||
bytes group_id() const;
|
||||
epoch_t epoch() const;
|
||||
WireFormat wire_format() const;
|
||||
|
||||
MLSMessage() = default;
|
||||
MLSMessage(PublicMessage public_message);
|
||||
MLSMessage(PrivateMessage private_message);
|
||||
MLSMessage(Welcome welcome);
|
||||
MLSMessage(GroupInfo group_info);
|
||||
MLSMessage(KeyPackage key_package);
|
||||
|
||||
TLS_SERIALIZABLE(version, message)
|
||||
TLS_TRAITS(tls::pass, tls::variant<WireFormat>)
|
||||
};
|
||||
|
||||
MLSMessage
|
||||
external_proposal(CipherSuite suite,
|
||||
const bytes& group_id,
|
||||
epoch_t epoch,
|
||||
const Proposal& proposal,
|
||||
uint32_t signer_index,
|
||||
const SignaturePrivateKey& sig_priv);
|
||||
|
||||
} // namespace mlspp
|
||||
|
||||
namespace mlspp::tls {
|
||||
|
||||
TLS_VARIANT_MAP(mlspp::PSKType, mlspp::ExternalPSK, external)
|
||||
TLS_VARIANT_MAP(mlspp::PSKType,
|
||||
mlspp::ResumptionPSK,
|
||||
resumption)
|
||||
|
||||
TLS_VARIANT_MAP(mlspp::ProposalOrRefType,
|
||||
mlspp::Proposal,
|
||||
value)
|
||||
TLS_VARIANT_MAP(mlspp::ProposalOrRefType,
|
||||
mlspp::ProposalRef,
|
||||
reference)
|
||||
|
||||
TLS_VARIANT_MAP(mlspp::ProposalType, mlspp::Add, add)
|
||||
TLS_VARIANT_MAP(mlspp::ProposalType, mlspp::Update, update)
|
||||
TLS_VARIANT_MAP(mlspp::ProposalType, mlspp::Remove, remove)
|
||||
TLS_VARIANT_MAP(mlspp::ProposalType, mlspp::PreSharedKey, psk)
|
||||
TLS_VARIANT_MAP(mlspp::ProposalType, mlspp::ReInit, reinit)
|
||||
TLS_VARIANT_MAP(mlspp::ProposalType,
|
||||
mlspp::ExternalInit,
|
||||
external_init)
|
||||
TLS_VARIANT_MAP(mlspp::ProposalType,
|
||||
mlspp::GroupContextExtensions,
|
||||
group_context_extensions)
|
||||
|
||||
TLS_VARIANT_MAP(mlspp::ContentType,
|
||||
mlspp::ApplicationData,
|
||||
application)
|
||||
TLS_VARIANT_MAP(mlspp::ContentType, mlspp::Proposal, proposal)
|
||||
TLS_VARIANT_MAP(mlspp::ContentType, mlspp::Commit, commit)
|
||||
|
||||
TLS_VARIANT_MAP(mlspp::SenderType, mlspp::MemberSender, member)
|
||||
TLS_VARIANT_MAP(mlspp::SenderType,
|
||||
mlspp::ExternalSenderIndex,
|
||||
external)
|
||||
TLS_VARIANT_MAP(mlspp::SenderType,
|
||||
mlspp::NewMemberProposalSender,
|
||||
new_member_proposal)
|
||||
TLS_VARIANT_MAP(mlspp::SenderType,
|
||||
mlspp::NewMemberCommitSender,
|
||||
new_member_commit)
|
||||
|
||||
TLS_VARIANT_MAP(mlspp::WireFormat,
|
||||
mlspp::PublicMessage,
|
||||
mls_public_message)
|
||||
TLS_VARIANT_MAP(mlspp::WireFormat,
|
||||
mlspp::PrivateMessage,
|
||||
mls_private_message)
|
||||
TLS_VARIANT_MAP(mlspp::WireFormat, mlspp::Welcome, mls_welcome)
|
||||
TLS_VARIANT_MAP(mlspp::WireFormat,
|
||||
mlspp::GroupInfo,
|
||||
mls_group_info)
|
||||
TLS_VARIANT_MAP(mlspp::WireFormat,
|
||||
mlspp::KeyPackage,
|
||||
mls_key_package)
|
||||
|
||||
} // namespace mlspp::tls
|
||||
98
DPP/mlspp/include/mls/session.h
Executable file
98
DPP/mlspp/include/mls/session.h
Executable file
@@ -0,0 +1,98 @@
|
||||
#pragma once
|
||||
|
||||
#include <mls/common.h>
|
||||
#include <mls/core_types.h>
|
||||
#include <mls/credential.h>
|
||||
#include <mls/crypto.h>
|
||||
#include <mls/state.h>
|
||||
|
||||
namespace mlspp {
|
||||
|
||||
class PendingJoin;
|
||||
class Session;
|
||||
|
||||
class Client
|
||||
{
|
||||
public:
|
||||
Client(CipherSuite suite_in,
|
||||
SignaturePrivateKey sig_priv_in,
|
||||
Credential cred_in);
|
||||
|
||||
Session begin_session(const bytes& group_id) const;
|
||||
|
||||
PendingJoin start_join() const;
|
||||
|
||||
private:
|
||||
const CipherSuite suite;
|
||||
const SignaturePrivateKey sig_priv;
|
||||
const Credential cred;
|
||||
};
|
||||
|
||||
class PendingJoin
|
||||
{
|
||||
public:
|
||||
PendingJoin(PendingJoin&& other) noexcept;
|
||||
PendingJoin& operator=(PendingJoin&& other) noexcept;
|
||||
~PendingJoin();
|
||||
bytes key_package() const;
|
||||
Session complete(const bytes& welcome) const;
|
||||
|
||||
private:
|
||||
struct Inner;
|
||||
std::unique_ptr<Inner> inner;
|
||||
|
||||
PendingJoin(Inner* inner);
|
||||
friend class Client;
|
||||
};
|
||||
|
||||
class Session
|
||||
{
|
||||
public:
|
||||
Session(Session&& other) noexcept;
|
||||
Session& operator=(Session&& other) noexcept;
|
||||
~Session();
|
||||
|
||||
// Settings
|
||||
void encrypt_handshake(bool enabled);
|
||||
|
||||
// Message producers
|
||||
bytes add(const bytes& key_package_data);
|
||||
bytes update();
|
||||
bytes remove(uint32_t index);
|
||||
std::tuple<bytes, bytes> commit(const bytes& proposal);
|
||||
std::tuple<bytes, bytes> commit(const std::vector<bytes>& proposals);
|
||||
std::tuple<bytes, bytes> commit();
|
||||
|
||||
// Message consumers
|
||||
bool handle(const bytes& handshake_data);
|
||||
|
||||
// Information about the current state
|
||||
epoch_t epoch() const;
|
||||
LeafIndex index() const;
|
||||
CipherSuite cipher_suite() const;
|
||||
const ExtensionList& extensions() const;
|
||||
const TreeKEMPublicKey& tree() const;
|
||||
bytes do_export(const std::string& label,
|
||||
const bytes& context,
|
||||
size_t size) const;
|
||||
GroupInfo group_info() const;
|
||||
std::vector<LeafNode> roster() const;
|
||||
bytes epoch_authenticator() const;
|
||||
|
||||
// Application message protection
|
||||
bytes protect(const bytes& plaintext);
|
||||
bytes unprotect(const bytes& ciphertext);
|
||||
|
||||
protected:
|
||||
struct Inner;
|
||||
std::unique_ptr<Inner> inner;
|
||||
|
||||
Session(Inner* inner);
|
||||
friend class Client;
|
||||
friend class PendingJoin;
|
||||
|
||||
friend bool operator==(const Session& lhs, const Session& rhs);
|
||||
friend bool operator!=(const Session& lhs, const Session& rhs);
|
||||
};
|
||||
|
||||
} // namespace mlspp
|
||||
431
DPP/mlspp/include/mls/state.h
Executable file
431
DPP/mlspp/include/mls/state.h
Executable file
@@ -0,0 +1,431 @@
|
||||
#pragma once
|
||||
|
||||
#include "mls/crypto.h"
|
||||
#include "mls/key_schedule.h"
|
||||
#include "mls/messages.h"
|
||||
#include "mls/treekem.h"
|
||||
#include <list>
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
|
||||
namespace mlspp {
|
||||
|
||||
// Index into the session roster
|
||||
struct RosterIndex : public UInt32
|
||||
{
|
||||
using UInt32::UInt32;
|
||||
};
|
||||
|
||||
struct CommitOpts
|
||||
{
|
||||
std::vector<Proposal> extra_proposals;
|
||||
bool inline_tree;
|
||||
bool force_path;
|
||||
LeafNodeOptions leaf_node_opts;
|
||||
};
|
||||
|
||||
struct MessageOpts
|
||||
{
|
||||
bool encrypt = false;
|
||||
bytes authenticated_data;
|
||||
size_t padding_size = 0;
|
||||
};
|
||||
|
||||
class State
|
||||
{
|
||||
public:
|
||||
///
|
||||
/// Constructors
|
||||
///
|
||||
|
||||
// Initialize an empty group
|
||||
State(bytes group_id,
|
||||
CipherSuite suite,
|
||||
HPKEPrivateKey enc_priv,
|
||||
SignaturePrivateKey sig_priv,
|
||||
const LeafNode& leaf_node,
|
||||
ExtensionList extensions);
|
||||
|
||||
// Initialize a group from a Welcome
|
||||
State(const HPKEPrivateKey& init_priv,
|
||||
HPKEPrivateKey leaf_priv,
|
||||
SignaturePrivateKey sig_priv,
|
||||
const KeyPackage& key_package,
|
||||
const Welcome& welcome,
|
||||
const std::optional<TreeKEMPublicKey>& tree,
|
||||
std::map<bytes, bytes> psks);
|
||||
|
||||
// Join a group from outside
|
||||
// XXX(RLB) To be fully general, we would need a few more options here, e.g.,
|
||||
// whether to include PSKs or evict our prior appearance.
|
||||
static std::tuple<MLSMessage, State> external_join(
|
||||
const bytes& leaf_secret,
|
||||
SignaturePrivateKey sig_priv,
|
||||
const KeyPackage& key_package,
|
||||
const GroupInfo& group_info,
|
||||
const std::optional<TreeKEMPublicKey>& tree,
|
||||
const MessageOpts& msg_opts,
|
||||
std::optional<LeafIndex> remove_prior,
|
||||
const std::map<bytes, bytes>& psks);
|
||||
|
||||
// Propose that a new member be added a group
|
||||
static MLSMessage new_member_add(const bytes& group_id,
|
||||
epoch_t epoch,
|
||||
const KeyPackage& new_member,
|
||||
const SignaturePrivateKey& sig_priv);
|
||||
|
||||
///
|
||||
/// Message factories
|
||||
///
|
||||
|
||||
Proposal add_proposal(const KeyPackage& key_package) const;
|
||||
Proposal update_proposal(HPKEPrivateKey leaf_priv,
|
||||
const LeafNodeOptions& opts);
|
||||
Proposal remove_proposal(RosterIndex index) const;
|
||||
Proposal remove_proposal(LeafIndex removed) const;
|
||||
Proposal group_context_extensions_proposal(ExtensionList exts) const;
|
||||
Proposal pre_shared_key_proposal(const bytes& external_psk_id) const;
|
||||
Proposal pre_shared_key_proposal(const bytes& group_id, epoch_t epoch) const;
|
||||
static Proposal reinit_proposal(bytes group_id,
|
||||
ProtocolVersion version,
|
||||
CipherSuite cipher_suite,
|
||||
ExtensionList extensions);
|
||||
|
||||
MLSMessage add(const KeyPackage& key_package, const MessageOpts& msg_opts);
|
||||
MLSMessage update(HPKEPrivateKey leaf_priv,
|
||||
const LeafNodeOptions& opts,
|
||||
const MessageOpts& msg_opts);
|
||||
MLSMessage remove(RosterIndex index, const MessageOpts& msg_opts);
|
||||
MLSMessage remove(LeafIndex removed, const MessageOpts& msg_opts);
|
||||
MLSMessage group_context_extensions(ExtensionList exts,
|
||||
const MessageOpts& msg_opts);
|
||||
MLSMessage pre_shared_key(const bytes& external_psk_id,
|
||||
const MessageOpts& msg_opts);
|
||||
MLSMessage pre_shared_key(const bytes& group_id,
|
||||
epoch_t epoch,
|
||||
const MessageOpts& msg_opts);
|
||||
MLSMessage reinit(bytes group_id,
|
||||
ProtocolVersion version,
|
||||
CipherSuite cipher_suite,
|
||||
ExtensionList extensions,
|
||||
const MessageOpts& msg_opts);
|
||||
|
||||
std::tuple<MLSMessage, Welcome, State> commit(
|
||||
const bytes& leaf_secret,
|
||||
const std::optional<CommitOpts>& opts,
|
||||
const MessageOpts& msg_opts);
|
||||
|
||||
///
|
||||
/// Generic handshake message handlers
|
||||
///
|
||||
std::optional<State> handle(const MLSMessage& msg);
|
||||
std::optional<State> handle(const MLSMessage& msg,
|
||||
std::optional<State> cached_state);
|
||||
|
||||
std::optional<State> handle(const ValidatedContent& content_auth);
|
||||
std::optional<State> handle(const ValidatedContent& content_auth,
|
||||
std::optional<State> cached_state);
|
||||
|
||||
///
|
||||
/// PSK management
|
||||
///
|
||||
void add_resumption_psk(const bytes& group_id, epoch_t epoch, bytes secret);
|
||||
void remove_resumption_psk(const bytes& group_id, epoch_t epoch);
|
||||
void add_external_psk(const bytes& id, const bytes& secret);
|
||||
void remove_external_psk(const bytes& id);
|
||||
|
||||
///
|
||||
/// Accessors
|
||||
///
|
||||
const bytes& group_id() const { return _group_id; }
|
||||
epoch_t epoch() const { return _epoch; }
|
||||
LeafIndex index() const { return _index; }
|
||||
CipherSuite cipher_suite() const { return _suite; }
|
||||
const ExtensionList& extensions() const { return _extensions; }
|
||||
const TreeKEMPublicKey& tree() const { return _tree; }
|
||||
const bytes& resumption_psk() const { return _key_schedule.resumption_psk; }
|
||||
|
||||
bytes do_export(const std::string& label,
|
||||
const bytes& context,
|
||||
size_t size) const;
|
||||
GroupInfo group_info(bool inline_tree) const;
|
||||
|
||||
// Ordered list of credentials from non-blank leaves
|
||||
std::vector<LeafNode> roster() const;
|
||||
|
||||
bytes epoch_authenticator() const;
|
||||
|
||||
///
|
||||
/// Unwrap messages so that applications can inspect them
|
||||
///
|
||||
ValidatedContent unwrap(const MLSMessage& msg);
|
||||
|
||||
///
|
||||
/// Application encryption and decryption
|
||||
///
|
||||
MLSMessage protect(const bytes& authenticated_data,
|
||||
const bytes& pt,
|
||||
size_t padding_size);
|
||||
std::tuple<bytes, bytes> unprotect(const MLSMessage& ct);
|
||||
|
||||
// Assemble a group context for this state
|
||||
GroupContext group_context() const;
|
||||
|
||||
// Subgroup branching
|
||||
std::tuple<State, Welcome> create_branch(
|
||||
bytes group_id,
|
||||
HPKEPrivateKey enc_priv,
|
||||
SignaturePrivateKey sig_priv,
|
||||
const LeafNode& leaf_node,
|
||||
ExtensionList extensions,
|
||||
const std::vector<KeyPackage>& key_packages,
|
||||
const bytes& leaf_secret,
|
||||
const CommitOpts& commit_opts) const;
|
||||
State handle_branch(const HPKEPrivateKey& init_priv,
|
||||
HPKEPrivateKey enc_priv,
|
||||
SignaturePrivateKey sig_priv,
|
||||
const KeyPackage& key_package,
|
||||
const Welcome& welcome,
|
||||
const std::optional<TreeKEMPublicKey>& tree) const;
|
||||
|
||||
// Reinitialization
|
||||
struct Tombstone
|
||||
{
|
||||
std::tuple<State, Welcome> create_welcome(
|
||||
HPKEPrivateKey enc_priv,
|
||||
SignaturePrivateKey sig_priv,
|
||||
const LeafNode& leaf_node,
|
||||
const std::vector<KeyPackage>& key_packages,
|
||||
const bytes& leaf_secret,
|
||||
const CommitOpts& commit_opts) const;
|
||||
State handle_welcome(const HPKEPrivateKey& init_priv,
|
||||
HPKEPrivateKey enc_priv,
|
||||
SignaturePrivateKey sig_priv,
|
||||
const KeyPackage& key_package,
|
||||
const Welcome& welcome,
|
||||
const std::optional<TreeKEMPublicKey>& tree) const;
|
||||
|
||||
TLS_SERIALIZABLE(prior_group_id, prior_epoch, resumption_psk, reinit);
|
||||
|
||||
const bytes epoch_authenticator;
|
||||
const ReInit reinit;
|
||||
|
||||
private:
|
||||
Tombstone(const State& state_in, ReInit reinit_in);
|
||||
|
||||
bytes prior_group_id;
|
||||
epoch_t prior_epoch;
|
||||
bytes resumption_psk;
|
||||
|
||||
friend class State;
|
||||
};
|
||||
|
||||
std::tuple<Tombstone, MLSMessage> reinit_commit(
|
||||
const bytes& leaf_secret,
|
||||
const std::optional<CommitOpts>& opts,
|
||||
const MessageOpts& msg_opts);
|
||||
Tombstone handle_reinit_commit(const MLSMessage& commit);
|
||||
|
||||
protected:
|
||||
// Shared confirmed state
|
||||
// XXX(rlb@ipv.sx): Can these be made const?
|
||||
CipherSuite _suite;
|
||||
bytes _group_id;
|
||||
epoch_t _epoch;
|
||||
TreeKEMPublicKey _tree;
|
||||
TreeKEMPrivateKey _tree_priv;
|
||||
TranscriptHash _transcript_hash;
|
||||
ExtensionList _extensions;
|
||||
|
||||
// Shared secret state
|
||||
KeyScheduleEpoch _key_schedule;
|
||||
GroupKeySource _keys;
|
||||
|
||||
// Per-participant state
|
||||
LeafIndex _index;
|
||||
SignaturePrivateKey _identity_priv;
|
||||
|
||||
// Storage for PSKs
|
||||
std::map<bytes, bytes> _external_psks;
|
||||
|
||||
using EpochRef = std::tuple<bytes, epoch_t>;
|
||||
std::map<EpochRef, bytes> _resumption_psks;
|
||||
|
||||
// Cache of Proposals and update secrets
|
||||
struct CachedProposal
|
||||
{
|
||||
ProposalRef ref;
|
||||
Proposal proposal;
|
||||
std::optional<LeafIndex> sender;
|
||||
};
|
||||
std::list<CachedProposal> _pending_proposals;
|
||||
|
||||
struct CachedUpdate
|
||||
{
|
||||
HPKEPrivateKey update_priv;
|
||||
Update proposal;
|
||||
};
|
||||
std::optional<CachedUpdate> _cached_update;
|
||||
|
||||
// Assemble a preliminary, unjoined group state
|
||||
State(SignaturePrivateKey sig_priv,
|
||||
const GroupInfo& group_info,
|
||||
const std::optional<TreeKEMPublicKey>& tree);
|
||||
|
||||
// Assemble a group from a Welcome, allowing for resumption PSKs
|
||||
State(const HPKEPrivateKey& init_priv,
|
||||
HPKEPrivateKey leaf_priv,
|
||||
SignaturePrivateKey sig_priv,
|
||||
const KeyPackage& key_package,
|
||||
const Welcome& welcome,
|
||||
const std::optional<TreeKEMPublicKey>& tree,
|
||||
std::map<bytes, bytes> external_psks,
|
||||
std::map<EpochRef, bytes> resumption_psks);
|
||||
|
||||
// Import a tree from an externally-provided tree or an extension
|
||||
TreeKEMPublicKey import_tree(const bytes& tree_hash,
|
||||
const std::optional<TreeKEMPublicKey>& external,
|
||||
const ExtensionList& extensions);
|
||||
bool validate_tree() const;
|
||||
|
||||
// Form a commit, covering all the cases with slightly different validation
|
||||
// rules:
|
||||
// * Normal
|
||||
// * External
|
||||
// * Branch
|
||||
// * Reinit
|
||||
struct NormalCommitParams
|
||||
{};
|
||||
|
||||
struct ExternalCommitParams
|
||||
{
|
||||
KeyPackage joiner_key_package;
|
||||
bytes force_init_secret;
|
||||
};
|
||||
|
||||
struct RestartCommitParams
|
||||
{
|
||||
ResumptionPSKUsage allowed_usage;
|
||||
};
|
||||
|
||||
struct ReInitCommitParams
|
||||
{};
|
||||
|
||||
using CommitParams = var::variant<NormalCommitParams,
|
||||
ExternalCommitParams,
|
||||
RestartCommitParams,
|
||||
ReInitCommitParams>;
|
||||
|
||||
std::tuple<MLSMessage, Welcome, State> commit(
|
||||
const bytes& leaf_secret,
|
||||
const std::optional<CommitOpts>& opts,
|
||||
const MessageOpts& msg_opts,
|
||||
CommitParams params);
|
||||
|
||||
std::optional<State> handle(
|
||||
const MLSMessage& msg,
|
||||
std::optional<State> cached_state,
|
||||
const std::optional<CommitParams>& expected_params);
|
||||
std::optional<State> handle(
|
||||
const ValidatedContent& val_content,
|
||||
std::optional<State> cached_state,
|
||||
const std::optional<CommitParams>& expected_params);
|
||||
|
||||
// Create an MLSMessage encapsulating some content
|
||||
template<typename Inner>
|
||||
AuthenticatedContent sign(const Sender& sender,
|
||||
Inner&& content,
|
||||
const bytes& authenticated_data,
|
||||
bool encrypt) const;
|
||||
|
||||
MLSMessage protect(AuthenticatedContent&& content_auth, size_t padding_size);
|
||||
|
||||
template<typename Inner>
|
||||
MLSMessage protect_full(Inner&& content, const MessageOpts& msg_opts);
|
||||
|
||||
// Apply the changes requested by various messages
|
||||
LeafIndex apply(const Add& add);
|
||||
void apply(LeafIndex target, const Update& update);
|
||||
void apply(LeafIndex target,
|
||||
const Update& update,
|
||||
const HPKEPrivateKey& leaf_priv);
|
||||
LeafIndex apply(const Remove& remove);
|
||||
void apply(const GroupContextExtensions& gce);
|
||||
std::vector<LeafIndex> apply(const std::vector<CachedProposal>& proposals,
|
||||
Proposal::Type required_type);
|
||||
std::tuple<std::vector<LeafIndex>, std::vector<PSKWithSecret>> apply(
|
||||
const std::vector<CachedProposal>& proposals);
|
||||
|
||||
// Verify that a specific key package or all members support a given set of
|
||||
// extensions
|
||||
bool extensions_supported(const ExtensionList& exts) const;
|
||||
|
||||
// Extract proposals and PSKs from cache
|
||||
void cache_proposal(AuthenticatedContent content_auth);
|
||||
std::optional<CachedProposal> resolve(
|
||||
const ProposalOrRef& id,
|
||||
std::optional<LeafIndex> sender_index) const;
|
||||
std::vector<CachedProposal> must_resolve(
|
||||
const std::vector<ProposalOrRef>& ids,
|
||||
std::optional<LeafIndex> sender_index) const;
|
||||
|
||||
std::vector<PSKWithSecret> resolve(
|
||||
const std::vector<PreSharedKeyID>& psks) const;
|
||||
|
||||
// Check properties of proposals
|
||||
bool valid(const LeafNode& leaf_node,
|
||||
LeafNodeSource required_source,
|
||||
std::optional<LeafIndex> index) const;
|
||||
bool valid(const KeyPackage& key_package) const;
|
||||
bool valid(const Add& add) const;
|
||||
bool valid(LeafIndex sender, const Update& update) const;
|
||||
bool valid(const Remove& remove) const;
|
||||
bool valid(const PreSharedKey& psk) const;
|
||||
static bool valid(const ReInit& reinit);
|
||||
bool valid(const ExternalInit& external_init) const;
|
||||
bool valid(const GroupContextExtensions& gce) const;
|
||||
bool valid(std::optional<LeafIndex> sender, const Proposal& proposal) const;
|
||||
|
||||
bool valid(const std::vector<CachedProposal>& proposals,
|
||||
LeafIndex commit_sender,
|
||||
const CommitParams& params) const;
|
||||
bool valid_normal(const std::vector<CachedProposal>& proposals,
|
||||
LeafIndex commit_sender) const;
|
||||
bool valid_external(const std::vector<CachedProposal>& proposals) const;
|
||||
static bool valid_reinit(const std::vector<CachedProposal>& proposals);
|
||||
static bool valid_restart(const std::vector<CachedProposal>& proposals,
|
||||
ResumptionPSKUsage allowed_usage);
|
||||
|
||||
static bool valid_external_proposal_type(const Proposal::Type proposal_type);
|
||||
|
||||
CommitParams infer_commit_type(
|
||||
const std::optional<LeafIndex>& sender,
|
||||
const std::vector<CachedProposal>& proposals,
|
||||
const std::optional<CommitParams>& expected_params) const;
|
||||
static bool path_required(const std::vector<CachedProposal>& proposals);
|
||||
|
||||
// Compare the **shared** attributes of the states
|
||||
friend bool operator==(const State& lhs, const State& rhs);
|
||||
friend bool operator!=(const State& lhs, const State& rhs);
|
||||
|
||||
// Derive and set the secrets for an epoch, given some new entropy
|
||||
void update_epoch_secrets(const bytes& commit_secret,
|
||||
const std::vector<PSKWithSecret>& psks,
|
||||
const std::optional<bytes>& force_init_secret);
|
||||
|
||||
// Signature verification over a handshake message
|
||||
bool verify_internal(const AuthenticatedContent& content_auth) const;
|
||||
bool verify_external(const AuthenticatedContent& content_auth) const;
|
||||
bool verify_new_member_proposal(
|
||||
const AuthenticatedContent& content_auth) const;
|
||||
bool verify_new_member_commit(const AuthenticatedContent& content_auth) const;
|
||||
bool verify(const AuthenticatedContent& content_auth) const;
|
||||
|
||||
// Convert a Roster entry into LeafIndex
|
||||
LeafIndex leaf_for_roster_entry(RosterIndex index) const;
|
||||
|
||||
// Create a draft successor state
|
||||
State successor() const;
|
||||
};
|
||||
|
||||
} // namespace mlspp
|
||||
107
DPP/mlspp/include/mls/tree_math.h
Executable file
107
DPP/mlspp/include/mls/tree_math.h
Executable file
@@ -0,0 +1,107 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <tls/tls_syntax.h>
|
||||
#include <vector>
|
||||
|
||||
// The below functions provide the index calculus for the tree
|
||||
// structures used in MLS. They are premised on a "flat"
|
||||
// representation of a balanced binary tree. Leaf nodes are
|
||||
// even-numbered nodes, with the n-th leaf at 2*n. Intermediate
|
||||
// nodes are held in odd-numbered nodes. For example, a 11-element
|
||||
// tree has the following structure:
|
||||
//
|
||||
// X
|
||||
// X
|
||||
// X X X
|
||||
// X X X X X
|
||||
// X X X X X X X X X X X
|
||||
// 0 1 2 3 4 5 6 7 8 9 a b c d e f 10 11 12 13 14
|
||||
//
|
||||
// This allows us to compute relationships between tree nodes simply
|
||||
// by manipulating indices, rather than having to maintain
|
||||
// complicated structures in memory, even for partial trees. (The
|
||||
// storage for a tree can just be a map[int]Node dictionary or an
|
||||
// array.) The basic rule is that the high-order bits of parent and
|
||||
// child nodes have the following relation:
|
||||
//
|
||||
// 01x = <00x, 10x>
|
||||
|
||||
namespace mlspp {
|
||||
|
||||
// Index types go in the overall namespace
|
||||
// XXX(rlb@ipv.sx): Seems like this stuff can probably get
|
||||
// simplified down a fair bit.
|
||||
struct UInt32
|
||||
{
|
||||
uint32_t val;
|
||||
|
||||
UInt32()
|
||||
: val(0)
|
||||
{
|
||||
}
|
||||
|
||||
explicit UInt32(uint32_t val_in)
|
||||
: val(val_in)
|
||||
{
|
||||
}
|
||||
|
||||
TLS_SERIALIZABLE(val)
|
||||
};
|
||||
|
||||
struct NodeCount;
|
||||
|
||||
struct LeafCount : public UInt32
|
||||
{
|
||||
using UInt32::UInt32;
|
||||
explicit LeafCount(const NodeCount w);
|
||||
|
||||
static LeafCount full(const LeafCount n);
|
||||
};
|
||||
|
||||
struct NodeCount : public UInt32
|
||||
{
|
||||
using UInt32::UInt32;
|
||||
explicit NodeCount(const LeafCount n);
|
||||
};
|
||||
|
||||
struct NodeIndex;
|
||||
|
||||
struct LeafIndex : public UInt32
|
||||
{
|
||||
using UInt32::UInt32;
|
||||
explicit LeafIndex(const NodeIndex x);
|
||||
bool operator<(const LeafIndex other) const { return val < other.val; }
|
||||
bool operator<(const LeafCount other) const { return val < other.val; }
|
||||
|
||||
NodeIndex ancestor(LeafIndex other) const;
|
||||
};
|
||||
|
||||
struct NodeIndex : public UInt32
|
||||
{
|
||||
using UInt32::UInt32;
|
||||
explicit NodeIndex(const LeafIndex x);
|
||||
bool operator<(const NodeIndex other) const { return val < other.val; }
|
||||
bool operator<(const NodeCount other) const { return val < other.val; }
|
||||
|
||||
static NodeIndex root(LeafCount n);
|
||||
|
||||
bool is_leaf() const;
|
||||
bool is_below(NodeIndex other) const;
|
||||
|
||||
NodeIndex left() const;
|
||||
NodeIndex right() const;
|
||||
NodeIndex parent() const;
|
||||
NodeIndex sibling() const;
|
||||
|
||||
// Returns the sibling of this node "relative to this ancestor" -- the child
|
||||
// of `ancestor` that is not in the direct path of this node.
|
||||
NodeIndex sibling(NodeIndex ancestor) const;
|
||||
|
||||
std::vector<NodeIndex> dirpath(LeafCount n);
|
||||
std::vector<NodeIndex> copath(LeafCount n);
|
||||
|
||||
uint32_t level() const;
|
||||
};
|
||||
|
||||
} // namespace mlspp
|
||||
255
DPP/mlspp/include/mls/treekem.h
Executable file
255
DPP/mlspp/include/mls/treekem.h
Executable file
@@ -0,0 +1,255 @@
|
||||
#pragma once
|
||||
|
||||
#include "mls/common.h"
|
||||
#include "mls/core_types.h"
|
||||
#include "mls/crypto.h"
|
||||
#include "mls/tree_math.h"
|
||||
#include <tls/tls_syntax.h>
|
||||
|
||||
#define ENABLE_TREE_DUMP 1
|
||||
|
||||
namespace mlspp {
|
||||
|
||||
enum struct NodeType : uint8_t
|
||||
{
|
||||
reserved = 0x00,
|
||||
leaf = 0x01,
|
||||
parent = 0x02,
|
||||
};
|
||||
|
||||
struct Node
|
||||
{
|
||||
var::variant<LeafNode, ParentNode> node;
|
||||
|
||||
const HPKEPublicKey& public_key() const;
|
||||
std::optional<bytes> parent_hash() const;
|
||||
|
||||
TLS_SERIALIZABLE(node)
|
||||
TLS_TRAITS(tls::variant<NodeType>)
|
||||
};
|
||||
|
||||
struct OptionalNode
|
||||
{
|
||||
std::optional<Node> node;
|
||||
|
||||
bool blank() const { return !node.has_value(); }
|
||||
bool leaf() const
|
||||
{
|
||||
return !blank() && var::holds_alternative<LeafNode>(opt::get(node).node);
|
||||
}
|
||||
|
||||
LeafNode& leaf_node() { return var::get<LeafNode>(opt::get(node).node); }
|
||||
|
||||
const LeafNode& leaf_node() const
|
||||
{
|
||||
return var::get<LeafNode>(opt::get(node).node);
|
||||
}
|
||||
|
||||
ParentNode& parent_node()
|
||||
{
|
||||
return var::get<ParentNode>(opt::get(node).node);
|
||||
}
|
||||
|
||||
const ParentNode& parent_node() const
|
||||
{
|
||||
return var::get<ParentNode>(opt::get(node).node);
|
||||
}
|
||||
|
||||
TLS_SERIALIZABLE(node)
|
||||
};
|
||||
|
||||
struct TreeKEMPublicKey;
|
||||
|
||||
struct TreeKEMPrivateKey
|
||||
{
|
||||
CipherSuite suite;
|
||||
LeafIndex index;
|
||||
bytes update_secret;
|
||||
std::map<NodeIndex, bytes> path_secrets;
|
||||
std::map<NodeIndex, HPKEPrivateKey> private_key_cache;
|
||||
|
||||
static TreeKEMPrivateKey solo(CipherSuite suite,
|
||||
LeafIndex index,
|
||||
HPKEPrivateKey leaf_priv);
|
||||
static TreeKEMPrivateKey create(const TreeKEMPublicKey& pub,
|
||||
LeafIndex from,
|
||||
const bytes& leaf_secret);
|
||||
static TreeKEMPrivateKey joiner(const TreeKEMPublicKey& pub,
|
||||
LeafIndex index,
|
||||
HPKEPrivateKey leaf_priv,
|
||||
NodeIndex intersect,
|
||||
const std::optional<bytes>& path_secret);
|
||||
|
||||
void set_leaf_priv(HPKEPrivateKey priv);
|
||||
std::tuple<NodeIndex, bytes, bool> shared_path_secret(LeafIndex to) const;
|
||||
|
||||
bool have_private_key(NodeIndex n) const;
|
||||
std::optional<HPKEPrivateKey> private_key(NodeIndex n);
|
||||
std::optional<HPKEPrivateKey> private_key(NodeIndex n) const;
|
||||
|
||||
void decap(LeafIndex from,
|
||||
const TreeKEMPublicKey& pub,
|
||||
const bytes& context,
|
||||
const UpdatePath& path,
|
||||
const std::vector<LeafIndex>& except);
|
||||
|
||||
void truncate(LeafCount size);
|
||||
|
||||
bool consistent(const TreeKEMPrivateKey& other) const;
|
||||
bool consistent(const TreeKEMPublicKey& other) const;
|
||||
|
||||
#if ENABLE_TREE_DUMP
|
||||
void dump() const;
|
||||
#endif
|
||||
|
||||
// TODO(RLB) Make this private but exposed to test vectors
|
||||
void implant(const TreeKEMPublicKey& pub,
|
||||
NodeIndex start,
|
||||
const bytes& path_secret);
|
||||
};
|
||||
|
||||
struct TreeKEMPublicKey
|
||||
{
|
||||
CipherSuite suite;
|
||||
LeafCount size{ 0 };
|
||||
std::vector<OptionalNode> nodes;
|
||||
|
||||
explicit TreeKEMPublicKey(CipherSuite suite);
|
||||
|
||||
TreeKEMPublicKey() = default;
|
||||
TreeKEMPublicKey(const TreeKEMPublicKey& other) = default;
|
||||
TreeKEMPublicKey(TreeKEMPublicKey&& other) = default;
|
||||
TreeKEMPublicKey& operator=(const TreeKEMPublicKey& other) = default;
|
||||
TreeKEMPublicKey& operator=(TreeKEMPublicKey&& other) = default;
|
||||
|
||||
LeafIndex allocate_leaf();
|
||||
LeafIndex add_leaf(const LeafNode& leaf);
|
||||
void update_leaf(LeafIndex index, const LeafNode& leaf);
|
||||
void blank_path(LeafIndex index);
|
||||
|
||||
TreeKEMPrivateKey update(LeafIndex from,
|
||||
const bytes& leaf_secret,
|
||||
const bytes& group_id,
|
||||
const SignaturePrivateKey& sig_priv,
|
||||
const LeafNodeOptions& opts);
|
||||
UpdatePath encap(const TreeKEMPrivateKey& priv,
|
||||
const bytes& context,
|
||||
const std::vector<LeafIndex>& except) const;
|
||||
|
||||
void merge(LeafIndex from, const UpdatePath& path);
|
||||
void set_hash_all();
|
||||
const bytes& get_hash(NodeIndex index);
|
||||
bytes root_hash() const;
|
||||
|
||||
bool parent_hash_valid(LeafIndex from, const UpdatePath& path) const;
|
||||
bool parent_hash_valid() const;
|
||||
|
||||
bool has_leaf(LeafIndex index) const;
|
||||
std::optional<LeafIndex> find(const LeafNode& leaf) const;
|
||||
std::optional<LeafNode> leaf_node(LeafIndex index) const;
|
||||
std::vector<NodeIndex> resolve(NodeIndex index) const;
|
||||
|
||||
template<typename UnaryPredicate>
|
||||
bool all_leaves(const UnaryPredicate& pred) const
|
||||
{
|
||||
for (LeafIndex i{ 0 }; i < size; i.val++) {
|
||||
const auto& node = node_at(i);
|
||||
if (node.blank()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!pred(i, node.leaf_node())) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template<typename UnaryPredicate>
|
||||
bool any_leaf(const UnaryPredicate& pred) const
|
||||
{
|
||||
for (LeafIndex i{ 0 }; i < size; i.val++) {
|
||||
const auto& node = node_at(i);
|
||||
if (node.blank()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (pred(i, node.leaf_node())) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
using FilteredDirectPath =
|
||||
std::vector<std::tuple<NodeIndex, std::vector<NodeIndex>>>;
|
||||
FilteredDirectPath filtered_direct_path(NodeIndex index) const;
|
||||
|
||||
void truncate();
|
||||
|
||||
OptionalNode& node_at(NodeIndex n);
|
||||
const OptionalNode& node_at(NodeIndex n) const;
|
||||
OptionalNode& node_at(LeafIndex n);
|
||||
const OptionalNode& node_at(LeafIndex n) const;
|
||||
|
||||
TLS_SERIALIZABLE(nodes)
|
||||
|
||||
#if ENABLE_TREE_DUMP
|
||||
void dump() const;
|
||||
#endif
|
||||
|
||||
private:
|
||||
std::map<NodeIndex, bytes> hashes;
|
||||
|
||||
void clear_hash_all();
|
||||
void clear_hash_path(LeafIndex index);
|
||||
|
||||
bool has_parent_hash(NodeIndex child, const bytes& target_ph) const;
|
||||
|
||||
bytes parent_hash(const ParentNode& parent, NodeIndex copath_child) const;
|
||||
std::vector<bytes> parent_hashes(
|
||||
LeafIndex from,
|
||||
const FilteredDirectPath& fdp,
|
||||
const std::vector<UpdatePathNode>& path_nodes) const;
|
||||
|
||||
using TreeHashCache = std::map<NodeIndex, std::pair<size_t, bytes>>;
|
||||
const bytes& original_tree_hash(TreeHashCache& cache,
|
||||
NodeIndex index,
|
||||
std::vector<LeafIndex> parent_except) const;
|
||||
bytes original_parent_hash(TreeHashCache& cache,
|
||||
NodeIndex parent,
|
||||
NodeIndex sibling) const;
|
||||
|
||||
bool exists_in_tree(const HPKEPublicKey& key,
|
||||
std::optional<LeafIndex> except) const;
|
||||
bool exists_in_tree(const SignaturePublicKey& key,
|
||||
std::optional<LeafIndex> except) const;
|
||||
|
||||
OptionalNode blank_node;
|
||||
|
||||
friend struct TreeKEMPrivateKey;
|
||||
};
|
||||
|
||||
tls::ostream&
|
||||
operator<<(tls::ostream& str, const TreeKEMPublicKey& obj);
|
||||
tls::istream&
|
||||
operator>>(tls::istream& str, TreeKEMPublicKey& obj);
|
||||
|
||||
struct LeafNodeHashInput;
|
||||
struct ParentNodeHashInput;
|
||||
|
||||
} // namespace mlspp
|
||||
|
||||
namespace mlspp::tls {
|
||||
|
||||
TLS_VARIANT_MAP(mlspp::NodeType, mlspp::LeafNodeHashInput, leaf)
|
||||
TLS_VARIANT_MAP(mlspp::NodeType,
|
||||
mlspp::ParentNodeHashInput,
|
||||
parent)
|
||||
|
||||
TLS_VARIANT_MAP(mlspp::NodeType, mlspp::LeafNode, leaf)
|
||||
TLS_VARIANT_MAP(mlspp::NodeType, mlspp::ParentNode, parent)
|
||||
|
||||
} // namespace mlspp::tls
|
||||
4
DPP/mlspp/include/namespace.h
Executable file
4
DPP/mlspp/include/namespace.h
Executable file
@@ -0,0 +1,4 @@
|
||||
#pragma once
|
||||
|
||||
// Configurable top-level MLS namespace
|
||||
#define MLS_NAMESPACE ../include/dpp/mlspp/mls
|
||||
5
DPP/mlspp/include/version.h
Executable file
5
DPP/mlspp/include/version.h
Executable file
@@ -0,0 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
/* Global version strings */
|
||||
extern const char VERSION[];
|
||||
extern const char HASHVAR[];
|
||||
4
DPP/mlspp/lib/CMakeLists.txt
Executable file
4
DPP/mlspp/lib/CMakeLists.txt
Executable file
@@ -0,0 +1,4 @@
|
||||
add_subdirectory(bytes)
|
||||
add_subdirectory(hpke)
|
||||
add_subdirectory(tls_syntax)
|
||||
add_subdirectory(mls_vectors)
|
||||
25
DPP/mlspp/lib/bytes/CMakeLists.txt
Executable file
25
DPP/mlspp/lib/bytes/CMakeLists.txt
Executable file
@@ -0,0 +1,25 @@
|
||||
set(CURRENT_LIB_NAME bytes)
|
||||
|
||||
###
|
||||
### Library Config
|
||||
###
|
||||
|
||||
file(GLOB_RECURSE LIB_HEADERS CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/include/*.h")
|
||||
file(GLOB_RECURSE LIB_SOURCES CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp")
|
||||
|
||||
add_library(${CURRENT_LIB_NAME} STATIC ${LIB_HEADERS} ${LIB_SOURCES})
|
||||
add_dependencies(${CURRENT_LIB_NAME} tls_syntax)
|
||||
include_directories("${PROJECT_SOURCE_DIR}/../bytes/include")
|
||||
target_link_libraries(${CURRENT_LIB_NAME} tls_syntax)
|
||||
target_include_directories(${CURRENT_LIB_NAME}
|
||||
PUBLIC
|
||||
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/include>
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
|
||||
$<INSTALL_INTERFACE:include/${PROJECT_NAME}>
|
||||
)
|
||||
|
||||
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/bytes/include")
|
||||
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/hpke/include")
|
||||
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/mls_vectors/include")
|
||||
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/tls_syntax/include")
|
||||
|
||||
127
DPP/mlspp/lib/bytes/include/bytes/bytes.h
Executable file
127
DPP/mlspp/lib/bytes/include/bytes/bytes.h
Executable file
@@ -0,0 +1,127 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <tls/tls_syntax.h>
|
||||
#include <vector>
|
||||
|
||||
namespace mlspp::bytes_ns {
|
||||
|
||||
struct bytes
|
||||
{
|
||||
// Ensure defaults
|
||||
bytes() = default;
|
||||
bytes(const bytes&) = default;
|
||||
bytes& operator=(const bytes&) = default;
|
||||
bytes(bytes&&) = default;
|
||||
bytes& operator=(bytes&&) = default;
|
||||
|
||||
// Zeroize on drop
|
||||
~bytes()
|
||||
{
|
||||
auto ptr = static_cast<volatile uint8_t*>(_data.data());
|
||||
std::fill(ptr, ptr + _data.size(), uint8_t(0));
|
||||
}
|
||||
|
||||
// Mimic std::vector ctors
|
||||
bytes(size_t count, const uint8_t& value = 0)
|
||||
: _data(count, value)
|
||||
{
|
||||
}
|
||||
|
||||
bytes(std::initializer_list<uint8_t> init)
|
||||
: _data(init)
|
||||
{
|
||||
}
|
||||
|
||||
template<size_t N>
|
||||
bytes(const std::array<uint8_t, N>& data)
|
||||
: _data(data.begin(), data.end())
|
||||
{
|
||||
}
|
||||
|
||||
// Slice out sub-vectors (to avoid an iterator ctor)
|
||||
bytes slice(size_t begin_index, size_t end_index) const
|
||||
{
|
||||
const auto begin_it = _data.begin() + begin_index;
|
||||
const auto end_it = _data.begin() + end_index;
|
||||
return std::vector<uint8_t>(begin_it, end_it);
|
||||
}
|
||||
|
||||
// Freely convert to/from std::vector
|
||||
bytes(const std::vector<uint8_t>& vec)
|
||||
: _data(vec)
|
||||
{
|
||||
}
|
||||
|
||||
bytes(std::vector<uint8_t>&& vec)
|
||||
: _data(vec)
|
||||
{
|
||||
}
|
||||
|
||||
operator const std::vector<uint8_t>&() const { return _data; }
|
||||
operator std::vector<uint8_t>&() { return _data; }
|
||||
operator std::vector<uint8_t>&&() && { return std::move(_data); }
|
||||
|
||||
const std::vector<uint8_t>& as_vec() const { return _data; }
|
||||
std::vector<uint8_t>& as_vec() { return _data; }
|
||||
|
||||
// Pass through methods
|
||||
auto data() const { return _data.data(); }
|
||||
auto data() { return _data.data(); }
|
||||
|
||||
auto size() const { return _data.size(); }
|
||||
auto empty() const { return _data.empty(); }
|
||||
|
||||
auto begin() const { return _data.begin(); }
|
||||
auto begin() { return _data.begin(); }
|
||||
|
||||
auto end() const { return _data.end(); }
|
||||
auto end() { return _data.end(); }
|
||||
|
||||
const auto& at(size_t pos) const { return _data.at(pos); }
|
||||
auto& at(size_t pos) { return _data.at(pos); }
|
||||
|
||||
void resize(size_t count) { _data.resize(count); }
|
||||
void reserve(size_t len) { _data.reserve(len); }
|
||||
void push_back(uint8_t byte) { _data.push_back(byte); }
|
||||
|
||||
// Equality operators
|
||||
bool operator==(const bytes& other) const;
|
||||
bool operator!=(const bytes& other) const;
|
||||
|
||||
bool operator==(const std::vector<uint8_t>& other) const;
|
||||
bool operator!=(const std::vector<uint8_t>& other) const;
|
||||
|
||||
// Arithmetic operators
|
||||
bytes& operator+=(const bytes& other);
|
||||
bytes operator+(const bytes& rhs) const;
|
||||
bytes operator^(const bytes& rhs) const;
|
||||
|
||||
// Sorting operators (to allow usage as map keys)
|
||||
bool operator<(const bytes& rhs) const;
|
||||
|
||||
// Other, external operators
|
||||
friend std::ostream& operator<<(std::ostream& out, const bytes& data);
|
||||
friend bool operator==(const std::vector<uint8_t>& lhs, const bytes& rhs);
|
||||
friend bool operator!=(const std::vector<uint8_t>& lhs, const bytes& rhs);
|
||||
|
||||
// TLS syntax serialization
|
||||
TLS_SERIALIZABLE(_data);
|
||||
|
||||
private:
|
||||
std::vector<uint8_t> _data;
|
||||
};
|
||||
|
||||
std::string
|
||||
to_ascii(const bytes& data);
|
||||
|
||||
bytes
|
||||
from_ascii(const std::string& ascii);
|
||||
|
||||
std::string
|
||||
to_hex(const bytes& data);
|
||||
|
||||
bytes
|
||||
from_hex(const std::string& hex);
|
||||
|
||||
} // namespace mlspp::bytes_ns
|
||||
61
DPP/mlspp/lib/bytes/include/tls/compat.h
Executable file
61
DPP/mlspp/lib/bytes/include/tls/compat.h
Executable file
@@ -0,0 +1,61 @@
|
||||
#pragma once
|
||||
|
||||
#include <optional>
|
||||
#include <stdexcept>
|
||||
|
||||
#ifdef VARIANT_COMPAT
|
||||
#include <variant.hpp>
|
||||
#else
|
||||
#include <variant>
|
||||
#endif // VARIANT_COMPAT
|
||||
|
||||
namespace mlspp::tls {
|
||||
|
||||
namespace var = std;
|
||||
|
||||
// In a similar vein, we provide our own safe accessors for std::optional, since
|
||||
// std::optional::value() is not available on macOS 10.11.
|
||||
namespace opt {
|
||||
|
||||
template<typename T>
|
||||
T&
|
||||
get(std::optional<T>& opt)
|
||||
{
|
||||
if (!opt) {
|
||||
throw std::runtime_error("bad_optional_access");
|
||||
}
|
||||
return *opt;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
const T&
|
||||
get(const std::optional<T>& opt)
|
||||
{
|
||||
if (!opt) {
|
||||
throw std::runtime_error("bad_optional_access");
|
||||
}
|
||||
return *opt;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
T&&
|
||||
get(std::optional<T>&& opt)
|
||||
{
|
||||
if (!opt) {
|
||||
throw std::runtime_error("bad_optional_access");
|
||||
}
|
||||
return std::move(*opt);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
const T&&
|
||||
get(const std::optional<T>&& opt)
|
||||
{
|
||||
if (!opt) {
|
||||
throw std::runtime_error("bad_optional_access");
|
||||
}
|
||||
return std::move(*opt);
|
||||
}
|
||||
|
||||
} // namespace opt
|
||||
} // namespace mlspp::tls
|
||||
569
DPP/mlspp/lib/bytes/include/tls/tls_syntax.h
Executable file
569
DPP/mlspp/lib/bytes/include/tls/tls_syntax.h
Executable file
@@ -0,0 +1,569 @@
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <optional>
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
|
||||
#include <tls/compat.h>
|
||||
|
||||
namespace mlspp::tls {
|
||||
|
||||
// For indicating no min or max in vector definitions
|
||||
const size_t none = std::numeric_limits<size_t>::max();
|
||||
|
||||
class WriteError : public std::invalid_argument
|
||||
{
|
||||
public:
|
||||
using parent = std::invalid_argument;
|
||||
using parent::parent;
|
||||
};
|
||||
|
||||
class ReadError : public std::invalid_argument
|
||||
{
|
||||
public:
|
||||
using parent = std::invalid_argument;
|
||||
using parent::parent;
|
||||
};
|
||||
|
||||
///
|
||||
/// Declarations of Streams and Traits
|
||||
///
|
||||
|
||||
class ostream
|
||||
{
|
||||
public:
|
||||
static const size_t none = std::numeric_limits<size_t>::max();
|
||||
|
||||
void write_raw(const std::vector<uint8_t>& bytes);
|
||||
|
||||
const std::vector<uint8_t>& bytes() const { return _buffer; }
|
||||
size_t size() const { return _buffer.size(); }
|
||||
bool empty() const { return _buffer.empty(); }
|
||||
|
||||
private:
|
||||
std::vector<uint8_t> _buffer;
|
||||
ostream& write_uint(uint64_t value, int length);
|
||||
|
||||
friend ostream& operator<<(ostream& out, bool data);
|
||||
friend ostream& operator<<(ostream& out, uint8_t data);
|
||||
friend ostream& operator<<(ostream& out, uint16_t data);
|
||||
friend ostream& operator<<(ostream& out, uint32_t data);
|
||||
friend ostream& operator<<(ostream& out, uint64_t data);
|
||||
|
||||
template<typename T>
|
||||
friend ostream& operator<<(ostream& out, const std::vector<T>& data);
|
||||
|
||||
friend struct varint;
|
||||
};
|
||||
|
||||
class istream
|
||||
{
|
||||
public:
|
||||
istream(const std::vector<uint8_t>& data)
|
||||
: _buffer(data)
|
||||
{
|
||||
// So that we can use the constant-time pop_back
|
||||
std::reverse(_buffer.begin(), _buffer.end());
|
||||
}
|
||||
|
||||
size_t size() const { return _buffer.size(); }
|
||||
bool empty() const { return _buffer.empty(); }
|
||||
|
||||
std::vector<uint8_t> bytes()
|
||||
{
|
||||
auto bytes = _buffer;
|
||||
std::reverse(bytes.begin(), bytes.end());
|
||||
return bytes;
|
||||
}
|
||||
|
||||
private:
|
||||
istream() {}
|
||||
std::vector<uint8_t> _buffer;
|
||||
uint8_t next();
|
||||
|
||||
template<typename T>
|
||||
istream& read_uint(T& data, size_t length)
|
||||
{
|
||||
uint64_t value = 0;
|
||||
for (size_t i = 0; i < length; i += 1) {
|
||||
value = (value << unsigned(8)) + next();
|
||||
}
|
||||
data = static_cast<T>(value);
|
||||
return *this;
|
||||
}
|
||||
|
||||
friend istream& operator>>(istream& in, bool& data);
|
||||
friend istream& operator>>(istream& in, uint8_t& data);
|
||||
friend istream& operator>>(istream& in, uint16_t& data);
|
||||
friend istream& operator>>(istream& in, uint32_t& data);
|
||||
friend istream& operator>>(istream& in, uint64_t& data);
|
||||
|
||||
template<typename T>
|
||||
friend istream& operator>>(istream& in, std::vector<T>& data);
|
||||
|
||||
friend struct varint;
|
||||
};
|
||||
|
||||
// Traits must have static encode and decode methods, of the following form:
|
||||
//
|
||||
// static ostream& encode(ostream& str, const T& val);
|
||||
// static istream& decode(istream& str, T& val);
|
||||
//
|
||||
// Trait types will never be constructed; only these static methods are used.
|
||||
// The value arguments to encode and decode can be as strict or as loose as
|
||||
// desired.
|
||||
//
|
||||
// Ultimately, all interesting encoding should be done through traits.
|
||||
//
|
||||
// * vectors
|
||||
// * variants
|
||||
// * varints
|
||||
|
||||
struct pass
|
||||
{
|
||||
template<typename T>
|
||||
static ostream& encode(ostream& str, const T& val);
|
||||
|
||||
template<typename T>
|
||||
static istream& decode(istream& str, T& val);
|
||||
};
|
||||
|
||||
template<typename Ts>
|
||||
struct variant
|
||||
{
|
||||
template<typename... Tp>
|
||||
static inline Ts type(const var::variant<Tp...>& data);
|
||||
|
||||
template<typename... Tp>
|
||||
static ostream& encode(ostream& str, const var::variant<Tp...>& data);
|
||||
|
||||
template<size_t I = 0, typename Te, typename... Tp>
|
||||
static inline typename std::enable_if<I == sizeof...(Tp), void>::type
|
||||
read_variant(istream&, Te, var::variant<Tp...>&);
|
||||
|
||||
template<size_t I = 0, typename Te, typename... Tp>
|
||||
static inline typename std::enable_if <
|
||||
I<sizeof...(Tp), void>::type read_variant(istream& str,
|
||||
Te target_type,
|
||||
var::variant<Tp...>& v);
|
||||
|
||||
template<typename... Tp>
|
||||
static istream& decode(istream& str, var::variant<Tp...>& data);
|
||||
};
|
||||
|
||||
struct varint
|
||||
{
|
||||
static ostream& encode(ostream& str, const uint64_t& val);
|
||||
static istream& decode(istream& str, uint64_t& val);
|
||||
};
|
||||
|
||||
///
|
||||
/// Writer implementations
|
||||
///
|
||||
|
||||
// Primitive writers defined in .cpp file
|
||||
|
||||
// Array writer
|
||||
template<typename T, size_t N>
|
||||
ostream&
|
||||
operator<<(ostream& out, const std::array<T, N>& data)
|
||||
{
|
||||
for (const auto& item : data) {
|
||||
out << item;
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
// Optional writer
|
||||
template<typename T>
|
||||
ostream&
|
||||
operator<<(ostream& out, const std::optional<T>& opt)
|
||||
{
|
||||
if (!opt) {
|
||||
return out << uint8_t(0);
|
||||
}
|
||||
|
||||
return out << uint8_t(1) << opt::get(opt);
|
||||
}
|
||||
|
||||
// Enum writer
|
||||
template<typename T, std::enable_if_t<std::is_enum<T>::value, int> = 0>
|
||||
ostream&
|
||||
operator<<(ostream& str, const T& val)
|
||||
{
|
||||
auto u = static_cast<std::underlying_type_t<T>>(val);
|
||||
return str << u;
|
||||
}
|
||||
|
||||
// Vector writer
|
||||
template<typename T>
|
||||
ostream&
|
||||
operator<<(ostream& str, const std::vector<T>& vec)
|
||||
{
|
||||
// Pre-encode contents
|
||||
ostream temp;
|
||||
for (const auto& item : vec) {
|
||||
temp << item;
|
||||
}
|
||||
|
||||
// Write the encoded length, then the pre-encoded data
|
||||
varint::encode(str, temp._buffer.size());
|
||||
str.write_raw(temp.bytes());
|
||||
|
||||
return str;
|
||||
}
|
||||
|
||||
///
|
||||
/// Reader implementations
|
||||
///
|
||||
|
||||
// Primitive type readers defined in .cpp file
|
||||
|
||||
// Array reader
|
||||
template<typename T, size_t N>
|
||||
istream&
|
||||
operator>>(istream& in, std::array<T, N>& data)
|
||||
{
|
||||
for (auto& item : data) {
|
||||
in >> item;
|
||||
}
|
||||
return in;
|
||||
}
|
||||
|
||||
// Optional reader
|
||||
template<typename T>
|
||||
istream&
|
||||
operator>>(istream& in, std::optional<T>& opt)
|
||||
{
|
||||
uint8_t present = 0;
|
||||
in >> present;
|
||||
|
||||
switch (present) {
|
||||
case 0:
|
||||
opt.reset();
|
||||
return in;
|
||||
|
||||
case 1:
|
||||
opt.emplace();
|
||||
return in >> opt::get(opt);
|
||||
|
||||
default:
|
||||
throw std::invalid_argument("Malformed optional");
|
||||
}
|
||||
}
|
||||
|
||||
// Enum reader
|
||||
// XXX(rlb): It would be nice if this could enforce that the values are valid,
|
||||
// but C++ doesn't seem to have that ability. When used as a tag for variants,
|
||||
// the variant reader will enforce, at least.
|
||||
template<typename T, std::enable_if_t<std::is_enum<T>::value, int> = 0>
|
||||
istream&
|
||||
operator>>(istream& str, T& val)
|
||||
{
|
||||
std::underlying_type_t<T> u;
|
||||
str >> u;
|
||||
val = static_cast<T>(u);
|
||||
return str;
|
||||
}
|
||||
|
||||
// Vector reader
|
||||
template<typename T>
|
||||
istream&
|
||||
operator>>(istream& str, std::vector<T>& vec)
|
||||
{
|
||||
// Read the encoded data size
|
||||
auto size = uint64_t(0);
|
||||
varint::decode(str, size);
|
||||
if (size > str._buffer.size()) {
|
||||
throw ReadError("Vector is longer than remaining data");
|
||||
}
|
||||
|
||||
// Read the elements of the vector
|
||||
// NB: Remember that we store the vector in reverse order
|
||||
// NB: This requires that T be default-constructible
|
||||
istream r;
|
||||
r._buffer =
|
||||
std::vector<uint8_t>{ str._buffer.end() - size, str._buffer.end() };
|
||||
|
||||
vec.clear();
|
||||
while (r._buffer.size() > 0) {
|
||||
vec.emplace_back();
|
||||
r >> vec.back();
|
||||
}
|
||||
|
||||
// Truncate the primary buffer
|
||||
str._buffer.erase(str._buffer.end() - size, str._buffer.end());
|
||||
|
||||
return str;
|
||||
}
|
||||
|
||||
// Abbreviations
|
||||
template<typename T>
|
||||
std::vector<uint8_t>
|
||||
marshal(const T& value)
|
||||
{
|
||||
ostream w;
|
||||
w << value;
|
||||
return w.bytes();
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void
|
||||
unmarshal(const std::vector<uint8_t>& data, T& value)
|
||||
{
|
||||
istream r(data);
|
||||
r >> value;
|
||||
}
|
||||
|
||||
template<typename T, typename... Tp>
|
||||
T
|
||||
get(const std::vector<uint8_t>& data, Tp... args)
|
||||
{
|
||||
T value(args...);
|
||||
unmarshal(data, value);
|
||||
return value;
|
||||
}
|
||||
|
||||
// Use this macro to define struct serialization with minimal boilerplate
|
||||
#define TLS_SERIALIZABLE(...) \
|
||||
static const bool _tls_serializable = true; \
|
||||
auto _tls_fields_r() \
|
||||
{ \
|
||||
return std::forward_as_tuple(__VA_ARGS__); \
|
||||
} \
|
||||
auto _tls_fields_w() const \
|
||||
{ \
|
||||
return std::forward_as_tuple(__VA_ARGS__); \
|
||||
}
|
||||
|
||||
// If your struct contains nontrivial members (e.g., vectors), use this to
|
||||
// define traits for them.
|
||||
#define TLS_TRAITS(...) \
|
||||
static const bool _tls_has_traits = true; \
|
||||
using _tls_traits = std::tuple<__VA_ARGS__>;
|
||||
|
||||
template<typename T>
|
||||
struct is_serializable
|
||||
{
|
||||
template<typename U>
|
||||
static std::true_type test(decltype(U::_tls_serializable));
|
||||
|
||||
template<typename U>
|
||||
static std::false_type test(...);
|
||||
|
||||
static const bool value = decltype(test<T>(true))::value;
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
struct has_traits
|
||||
{
|
||||
template<typename U>
|
||||
static std::true_type test(decltype(U::_tls_has_traits));
|
||||
|
||||
template<typename U>
|
||||
static std::false_type test(...);
|
||||
|
||||
static const bool value = decltype(test<T>(true))::value;
|
||||
};
|
||||
|
||||
///
|
||||
/// Trait implementations
|
||||
///
|
||||
|
||||
// Pass-through (normal encoding/decoding)
|
||||
template<typename T>
|
||||
ostream&
|
||||
pass::encode(ostream& str, const T& val)
|
||||
{
|
||||
return str << val;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
istream&
|
||||
pass::decode(istream& str, T& val)
|
||||
{
|
||||
return str >> val;
|
||||
}
|
||||
|
||||
// Variant encoding
|
||||
template<typename Ts, typename Tv>
|
||||
constexpr Ts
|
||||
variant_map();
|
||||
|
||||
#define TLS_VARIANT_MAP(EnumType, MappedType, enum_value) \
|
||||
template<> \
|
||||
constexpr EnumType variant_map<EnumType, MappedType>() \
|
||||
{ \
|
||||
return EnumType::enum_value; \
|
||||
}
|
||||
|
||||
template<typename Ts>
|
||||
template<typename... Tp>
|
||||
inline Ts
|
||||
variant<Ts>::type(const var::variant<Tp...>& data)
|
||||
{
|
||||
const auto get_type = [](const auto& v) {
|
||||
return variant_map<Ts, std::decay_t<decltype(v)>>();
|
||||
};
|
||||
return var::visit(get_type, data);
|
||||
}
|
||||
|
||||
template<typename Ts>
|
||||
template<typename... Tp>
|
||||
ostream&
|
||||
variant<Ts>::encode(ostream& str, const var::variant<Tp...>& data)
|
||||
{
|
||||
const auto write_variant = [&str](auto&& v) {
|
||||
using Tv = std::decay_t<decltype(v)>;
|
||||
str << variant_map<Ts, Tv>() << v;
|
||||
};
|
||||
var::visit(write_variant, data);
|
||||
return str;
|
||||
}
|
||||
|
||||
template<typename Ts>
|
||||
template<size_t I, typename Te, typename... Tp>
|
||||
inline typename std::enable_if<I == sizeof...(Tp), void>::type
|
||||
variant<Ts>::read_variant(istream&, Te, var::variant<Tp...>&)
|
||||
{
|
||||
throw ReadError("Invalid variant type label");
|
||||
}
|
||||
|
||||
template<typename Ts>
|
||||
template<size_t I, typename Te, typename... Tp>
|
||||
inline
|
||||
typename std::enable_if < I<sizeof...(Tp), void>::type
|
||||
variant<Ts>::read_variant(istream& str,
|
||||
Te target_type,
|
||||
var::variant<Tp...>& v)
|
||||
{
|
||||
using Tc = var::variant_alternative_t<I, var::variant<Tp...>>;
|
||||
if (variant_map<Ts, Tc>() == target_type) {
|
||||
str >> v.template emplace<I>();
|
||||
return;
|
||||
}
|
||||
|
||||
read_variant<I + 1>(str, target_type, v);
|
||||
}
|
||||
|
||||
template<typename Ts>
|
||||
template<typename... Tp>
|
||||
istream&
|
||||
variant<Ts>::decode(istream& str, var::variant<Tp...>& data)
|
||||
{
|
||||
Ts target_type;
|
||||
str >> target_type;
|
||||
read_variant(str, target_type, data);
|
||||
return str;
|
||||
}
|
||||
|
||||
// Struct writer without traits (enabled by macro)
|
||||
template<size_t I = 0, typename... Tp>
|
||||
inline typename std::enable_if<I == sizeof...(Tp), void>::type
|
||||
write_tuple(ostream&, const std::tuple<Tp...>&)
|
||||
{
|
||||
}
|
||||
|
||||
template<size_t I = 0, typename... Tp>
|
||||
inline typename std::enable_if <
|
||||
I<sizeof...(Tp), void>::type
|
||||
write_tuple(ostream& str, const std::tuple<Tp...>& t)
|
||||
{
|
||||
str << std::get<I>(t);
|
||||
write_tuple<I + 1, Tp...>(str, t);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline
|
||||
typename std::enable_if<is_serializable<T>::value && !has_traits<T>::value,
|
||||
ostream&>::type
|
||||
operator<<(ostream& str, const T& obj)
|
||||
{
|
||||
write_tuple(str, obj._tls_fields_w());
|
||||
return str;
|
||||
}
|
||||
|
||||
// Struct writer with traits (enabled by macro)
|
||||
template<typename Tr, size_t I = 0, typename... Tp>
|
||||
inline typename std::enable_if<I == sizeof...(Tp), void>::type
|
||||
write_tuple_traits(ostream&, const std::tuple<Tp...>&)
|
||||
{
|
||||
}
|
||||
|
||||
template<typename Tr, size_t I = 0, typename... Tp>
|
||||
inline typename std::enable_if <
|
||||
I<sizeof...(Tp), void>::type
|
||||
write_tuple_traits(ostream& str, const std::tuple<Tp...>& t)
|
||||
{
|
||||
std::tuple_element_t<I, Tr>::encode(str, std::get<I>(t));
|
||||
write_tuple_traits<Tr, I + 1, Tp...>(str, t);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline
|
||||
typename std::enable_if<is_serializable<T>::value && has_traits<T>::value,
|
||||
ostream&>::type
|
||||
operator<<(ostream& str, const T& obj)
|
||||
{
|
||||
write_tuple_traits<typename T::_tls_traits>(str, obj._tls_fields_w());
|
||||
return str;
|
||||
}
|
||||
|
||||
// Struct reader without traits (enabled by macro)
|
||||
template<size_t I = 0, typename... Tp>
|
||||
inline typename std::enable_if<I == sizeof...(Tp), void>::type
|
||||
read_tuple(istream&, const std::tuple<Tp...>&)
|
||||
{
|
||||
}
|
||||
|
||||
template<size_t I = 0, typename... Tp>
|
||||
inline
|
||||
typename std::enable_if < I<sizeof...(Tp), void>::type
|
||||
read_tuple(istream& str, const std::tuple<Tp...>& t)
|
||||
{
|
||||
str >> std::get<I>(t);
|
||||
read_tuple<I + 1, Tp...>(str, t);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline
|
||||
typename std::enable_if<is_serializable<T>::value && !has_traits<T>::value,
|
||||
istream&>::type
|
||||
operator>>(istream& str, T& obj)
|
||||
{
|
||||
read_tuple(str, obj._tls_fields_r());
|
||||
return str;
|
||||
}
|
||||
|
||||
// Struct reader with traits (enabled by macro)
|
||||
template<typename Tr, size_t I = 0, typename... Tp>
|
||||
inline typename std::enable_if<I == sizeof...(Tp), void>::type
|
||||
read_tuple_traits(istream&, const std::tuple<Tp...>&)
|
||||
{
|
||||
}
|
||||
|
||||
template<typename Tr, size_t I = 0, typename... Tp>
|
||||
inline typename std::enable_if <
|
||||
I<sizeof...(Tp), void>::type
|
||||
read_tuple_traits(istream& str, const std::tuple<Tp...>& t)
|
||||
{
|
||||
std::tuple_element_t<I, Tr>::decode(str, std::get<I>(t));
|
||||
read_tuple_traits<Tr, I + 1, Tp...>(str, t);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline
|
||||
typename std::enable_if<is_serializable<T>::value && has_traits<T>::value,
|
||||
istream&>::type
|
||||
operator>>(istream& str, T& obj)
|
||||
{
|
||||
read_tuple_traits<typename T::_tls_traits>(str, obj._tls_fields_r());
|
||||
return str;
|
||||
}
|
||||
|
||||
} // namespace mlspp::tls
|
||||
146
DPP/mlspp/lib/bytes/src/bytes.cpp
Executable file
146
DPP/mlspp/lib/bytes/src/bytes.cpp
Executable file
@@ -0,0 +1,146 @@
|
||||
#include <bytes/bytes.h>
|
||||
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
|
||||
namespace mlspp::bytes_ns {
|
||||
|
||||
bool
|
||||
bytes::operator==(const bytes& other) const
|
||||
{
|
||||
return *this == other._data;
|
||||
}
|
||||
|
||||
bool
|
||||
bytes::operator!=(const bytes& other) const
|
||||
{
|
||||
return !(*this == other._data);
|
||||
}
|
||||
|
||||
bool
|
||||
bytes::operator==(const std::vector<uint8_t>& other) const
|
||||
{
|
||||
const size_t size = other.size();
|
||||
if (_data.size() != size) {
|
||||
return false;
|
||||
}
|
||||
|
||||
unsigned char diff = 0;
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
// Not sure why the linter thinks `diff` is signed
|
||||
// NOLINTNEXTLINE(hicpp-signed-bitwise)
|
||||
diff |= (_data.at(i) ^ other.at(i));
|
||||
}
|
||||
return (diff == 0);
|
||||
}
|
||||
|
||||
bool
|
||||
bytes::operator!=(const std::vector<uint8_t>& other) const
|
||||
{
|
||||
return !(*this == other);
|
||||
}
|
||||
|
||||
bytes&
|
||||
bytes::operator+=(const bytes& other)
|
||||
{
|
||||
// Not sure what the default argument is here
|
||||
// NOLINTNEXTLINE(fuchsia-default-arguments)
|
||||
_data.insert(end(), other.begin(), other.end());
|
||||
return *this;
|
||||
}
|
||||
|
||||
bytes
|
||||
bytes::operator+(const bytes& rhs) const
|
||||
{
|
||||
bytes out = *this;
|
||||
out += rhs;
|
||||
return out;
|
||||
}
|
||||
|
||||
bool
|
||||
bytes::operator<(const bytes& rhs) const
|
||||
{
|
||||
return _data < rhs._data;
|
||||
}
|
||||
|
||||
bytes
|
||||
bytes::operator^(const bytes& rhs) const
|
||||
{
|
||||
if (size() != rhs.size()) {
|
||||
throw std::invalid_argument("XOR with unequal size");
|
||||
}
|
||||
|
||||
bytes out = *this;
|
||||
for (size_t i = 0; i < size(); ++i) {
|
||||
out.at(i) ^= rhs.at(i);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
std::string
|
||||
to_ascii(const bytes& data)
|
||||
{
|
||||
return { data.begin(), data.end() };
|
||||
}
|
||||
|
||||
bytes
|
||||
from_ascii(const std::string& ascii)
|
||||
{
|
||||
return std::vector<uint8_t>(ascii.begin(), ascii.end());
|
||||
}
|
||||
|
||||
std::string
|
||||
to_hex(const bytes& data)
|
||||
{
|
||||
std::stringstream hex(std::ios_base::out);
|
||||
hex.flags(std::ios::hex);
|
||||
for (const auto& byte : data) {
|
||||
hex << std::setw(2) << std::setfill('0') << int(byte);
|
||||
}
|
||||
return hex.str();
|
||||
}
|
||||
|
||||
bytes
|
||||
from_hex(const std::string& hex)
|
||||
{
|
||||
if (hex.length() % 2 == 1) {
|
||||
throw std::invalid_argument("Odd-length hex string");
|
||||
}
|
||||
|
||||
auto len = hex.length() / 2;
|
||||
auto out = bytes(len);
|
||||
for (size_t i = 0; i < len; i += 1) {
|
||||
const std::string byte = hex.substr(2 * i, 2);
|
||||
out.at(i) = static_cast<uint8_t>(strtol(byte.c_str(), nullptr, 16));
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
std::ostream&
|
||||
operator<<(std::ostream& out, const bytes& data)
|
||||
{
|
||||
// Adjust this threshold to make output more compact
|
||||
const size_t threshold = 0xffff;
|
||||
if (data.size() < threshold) {
|
||||
return out << to_hex(data);
|
||||
}
|
||||
|
||||
return out << to_hex(data.slice(0, threshold)) << "...";
|
||||
}
|
||||
|
||||
bool
|
||||
operator==(const std::vector<uint8_t>& lhs, const bytes_ns::bytes& rhs)
|
||||
{
|
||||
return rhs == lhs;
|
||||
}
|
||||
|
||||
bool
|
||||
operator!=(const std::vector<uint8_t>& lhs, const bytes_ns::bytes& rhs)
|
||||
{
|
||||
return rhs != lhs;
|
||||
}
|
||||
|
||||
} // namespace mlspp::bytes_ns
|
||||
57
DPP/mlspp/lib/hpke/CMakeLists.txt
Executable file
57
DPP/mlspp/lib/hpke/CMakeLists.txt
Executable file
@@ -0,0 +1,57 @@
|
||||
set(CURRENT_LIB_NAME hpke)
|
||||
|
||||
###
|
||||
### Library Config
|
||||
###
|
||||
|
||||
file(GLOB_RECURSE LIB_HEADERS CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/include/*.h")
|
||||
file(GLOB_RECURSE LIB_SOURCES CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp")
|
||||
|
||||
# -Werror=dangling-reference
|
||||
|
||||
add_library(${CURRENT_LIB_NAME} STATIC ${LIB_HEADERS} ${LIB_SOURCES})
|
||||
add_dependencies(${CURRENT_LIB_NAME} bytes tls_syntax)
|
||||
target_include_directories(${CURRENT_LIB_NAME}
|
||||
PRIVATE
|
||||
"${JSON_INCLUDE_INTERFACE}")
|
||||
|
||||
target_link_libraries(${CURRENT_LIB_NAME}
|
||||
PUBLIC
|
||||
bytes tls_syntax
|
||||
)
|
||||
|
||||
target_include_directories(${CURRENT_LIB_NAME}
|
||||
PUBLIC
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
|
||||
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/include>
|
||||
$<INSTALL_INTERFACE:include/${PROJECT_NAME}>
|
||||
PRIVATE
|
||||
${OPENSSL_INCLUDE_DIR}
|
||||
)
|
||||
|
||||
# Private statically linked dependencies
|
||||
target_link_libraries(${CURRENT_LIB_NAME} PRIVATE
|
||||
mlspp
|
||||
mls_vectors
|
||||
bytes
|
||||
tls_syntax
|
||||
)
|
||||
|
||||
target_compile_options(
|
||||
"${CURRENT_LIB_NAME}"
|
||||
PUBLIC
|
||||
"$<$<PLATFORM_ID:Windows>:/bigobj;/Zc:preprocessor>"
|
||||
PRIVATE
|
||||
"$<$<PLATFORM_ID:Windows>:$<$<CONFIG:Debug>:/sdl;/Od;/DEBUG;/MP;/DFD_SETSIZE=1024>>"
|
||||
"$<$<PLATFORM_ID:Windows>:$<$<CONFIG:Release>:/O2;/Oi;/Oy;/GL;/Gy;/sdl;/MP;/DFD_SETSIZE=1024>>"
|
||||
"$<$<PLATFORM_ID:Linux>:$<$<CONFIG:Debug>:-Wall;-Wempty-body;-Wno-psabi;-Wunknown-pragmas;-Wignored-qualifiers;-Wimplicit-fallthrough;-Wmissing-field-initializers;-Wsign-compare;-Wtype-limits;-Wuninitialized;-Wshift-negative-value;-pthread;-g;-Og;-fPIC>>"
|
||||
"$<$<PLATFORM_ID:Linux>:$<$<CONFIG:Release>:-Wall;-Wempty-body;-Wno-psabi;-Wunknown-pragmas;-Wignored-qualifiers;-Wimplicit-fallthrough;-Wmissing-field-initializers;-Wsign-compare;-Wtype-limits;-Wuninitialized;-Wshift-negative-value;-pthread;-O3;-fPIC>>"
|
||||
"${AVX_FLAG}"
|
||||
)
|
||||
|
||||
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/bytes/include")
|
||||
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/hpke/include")
|
||||
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/mls_vectors/include")
|
||||
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/tls_syntax/include")
|
||||
# For nlohmann/json.hpp
|
||||
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../../include")
|
||||
20
DPP/mlspp/lib/hpke/include/hpke/base64.h
Executable file
20
DPP/mlspp/lib/hpke/include/hpke/base64.h
Executable file
@@ -0,0 +1,20 @@
|
||||
#pragma once
|
||||
|
||||
#include <bytes/bytes.h>
|
||||
using namespace mlspp::bytes_ns;
|
||||
|
||||
namespace mlspp::hpke {
|
||||
|
||||
std::string
|
||||
to_base64(const bytes& data);
|
||||
|
||||
std::string
|
||||
to_base64url(const bytes& data);
|
||||
|
||||
bytes
|
||||
from_base64(const std::string& enc);
|
||||
|
||||
bytes
|
||||
from_base64url(const std::string& enc);
|
||||
|
||||
} // namespace mlspp::hpke
|
||||
75
DPP/mlspp/lib/hpke/include/hpke/certificate.h
Executable file
75
DPP/mlspp/lib/hpke/include/hpke/certificate.h
Executable file
@@ -0,0 +1,75 @@
|
||||
#pragma once
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
|
||||
#include <bytes/bytes.h>
|
||||
#include <chrono>
|
||||
#include <hpke/signature.h>
|
||||
#include <map>
|
||||
|
||||
using namespace mlspp::bytes_ns;
|
||||
|
||||
namespace mlspp::hpke {
|
||||
|
||||
struct Certificate
|
||||
{
|
||||
private:
|
||||
struct ParsedCertificate;
|
||||
std::unique_ptr<ParsedCertificate> parsed_cert;
|
||||
|
||||
public:
|
||||
struct NameType
|
||||
{
|
||||
static const int organization;
|
||||
static const int common_name;
|
||||
static const int organizational_unit;
|
||||
static const int country;
|
||||
static const int serial_number;
|
||||
static const int state_or_province_name;
|
||||
};
|
||||
|
||||
using ParsedName = std::map<int, std::string>;
|
||||
|
||||
// Certificate Expiration Status
|
||||
enum struct ExpirationStatus
|
||||
{
|
||||
inactive, // now < notBefore
|
||||
active, // notBefore < now < notAfter
|
||||
expired, // notAfter < now
|
||||
};
|
||||
|
||||
explicit Certificate(const bytes& der);
|
||||
explicit Certificate(std::unique_ptr<ParsedCertificate>&& parsed_cert_in);
|
||||
Certificate() = delete;
|
||||
Certificate(const Certificate& other);
|
||||
~Certificate();
|
||||
|
||||
static std::vector<Certificate> parse_pem(const bytes& pem);
|
||||
|
||||
bool valid_from(const Certificate& parent) const;
|
||||
|
||||
// Accessors for parsed certificate elements
|
||||
uint64_t issuer_hash() const;
|
||||
uint64_t subject_hash() const;
|
||||
ParsedName issuer() const;
|
||||
ParsedName subject() const;
|
||||
bool is_ca() const;
|
||||
ExpirationStatus expiration_status() const;
|
||||
std::optional<bytes> subject_key_id() const;
|
||||
std::optional<bytes> authority_key_id() const;
|
||||
std::vector<std::string> email_addresses() const;
|
||||
std::vector<std::string> dns_names() const;
|
||||
bytes hash() const;
|
||||
std::chrono::system_clock::time_point not_before() const;
|
||||
std::chrono::system_clock::time_point not_after() const;
|
||||
Signature::ID public_key_algorithm() const;
|
||||
Signature::ID signature_algorithm() const;
|
||||
|
||||
const std::unique_ptr<Signature::PublicKey> public_key;
|
||||
const bytes raw;
|
||||
};
|
||||
|
||||
bool
|
||||
operator==(const Certificate& lhs, const Certificate& rhs);
|
||||
|
||||
} // namespace mlspp::hpke
|
||||
37
DPP/mlspp/lib/hpke/include/hpke/digest.h
Executable file
37
DPP/mlspp/lib/hpke/include/hpke/digest.h
Executable file
@@ -0,0 +1,37 @@
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include <bytes/bytes.h>
|
||||
|
||||
using namespace mlspp::bytes_ns;
|
||||
|
||||
namespace mlspp::hpke {
|
||||
|
||||
struct Digest
|
||||
{
|
||||
enum struct ID
|
||||
{
|
||||
SHA256,
|
||||
SHA384,
|
||||
SHA512,
|
||||
};
|
||||
|
||||
template<ID id>
|
||||
static const Digest& get();
|
||||
|
||||
const ID id;
|
||||
|
||||
bytes hash(const bytes& data) const;
|
||||
bytes hmac(const bytes& key, const bytes& data) const;
|
||||
|
||||
const size_t hash_size;
|
||||
|
||||
private:
|
||||
explicit Digest(ID id);
|
||||
|
||||
bytes hmac_for_hkdf_extract(const bytes& key, const bytes& data) const;
|
||||
friend struct HKDF;
|
||||
};
|
||||
|
||||
} // namespace mlspp::hpke
|
||||
253
DPP/mlspp/lib/hpke/include/hpke/hpke.h
Executable file
253
DPP/mlspp/lib/hpke/include/hpke/hpke.h
Executable file
@@ -0,0 +1,253 @@
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
|
||||
#include <bytes/bytes.h>
|
||||
using namespace mlspp::bytes_ns;
|
||||
|
||||
namespace mlspp::hpke {
|
||||
|
||||
struct KEM
|
||||
{
|
||||
enum struct ID : uint16_t
|
||||
{
|
||||
DHKEM_P256_SHA256 = 0x0010,
|
||||
DHKEM_P384_SHA384 = 0x0011,
|
||||
DHKEM_P521_SHA512 = 0x0012,
|
||||
DHKEM_X25519_SHA256 = 0x0020,
|
||||
#if !defined(WITH_BORINGSSL)
|
||||
DHKEM_X448_SHA512 = 0x0021,
|
||||
#endif
|
||||
};
|
||||
|
||||
template<KEM::ID>
|
||||
static const KEM& get();
|
||||
|
||||
virtual ~KEM() = default;
|
||||
|
||||
struct PublicKey
|
||||
{
|
||||
virtual ~PublicKey() = default;
|
||||
};
|
||||
|
||||
struct PrivateKey
|
||||
{
|
||||
virtual ~PrivateKey() = default;
|
||||
virtual std::unique_ptr<PublicKey> public_key() const = 0;
|
||||
};
|
||||
|
||||
const ID id;
|
||||
const size_t secret_size;
|
||||
const size_t enc_size;
|
||||
const size_t pk_size;
|
||||
const size_t sk_size;
|
||||
|
||||
virtual std::unique_ptr<PrivateKey> generate_key_pair() const = 0;
|
||||
virtual std::unique_ptr<PrivateKey> derive_key_pair(
|
||||
const bytes& ikm) const = 0;
|
||||
|
||||
virtual bytes serialize(const PublicKey& pk) const = 0;
|
||||
virtual std::unique_ptr<PublicKey> deserialize(const bytes& enc) const = 0;
|
||||
|
||||
virtual bytes serialize_private(const PrivateKey& sk) const;
|
||||
virtual std::unique_ptr<PrivateKey> deserialize_private(
|
||||
const bytes& skm) const;
|
||||
|
||||
// (shared_secret, enc)
|
||||
virtual std::pair<bytes, bytes> encap(const PublicKey& pkR) const = 0;
|
||||
virtual bytes decap(const bytes& enc, const PrivateKey& skR) const = 0;
|
||||
|
||||
// (shared_secret, enc)
|
||||
virtual std::pair<bytes, bytes> auth_encap(const PublicKey& pkR,
|
||||
const PrivateKey& skS) const;
|
||||
virtual bytes auth_decap(const bytes& enc,
|
||||
const PublicKey& pkS,
|
||||
const PrivateKey& skR) const;
|
||||
|
||||
protected:
|
||||
KEM(ID id_in,
|
||||
size_t secret_size_in,
|
||||
size_t enc_size_in,
|
||||
size_t pk_size_in,
|
||||
size_t sk_size_in);
|
||||
};
|
||||
|
||||
struct KDF
|
||||
{
|
||||
enum struct ID : uint16_t
|
||||
{
|
||||
HKDF_SHA256 = 0x0001,
|
||||
HKDF_SHA384 = 0x0002,
|
||||
HKDF_SHA512 = 0x0003,
|
||||
};
|
||||
|
||||
template<KDF::ID id>
|
||||
static const KDF& get();
|
||||
|
||||
virtual ~KDF() = default;
|
||||
|
||||
const ID id;
|
||||
const size_t hash_size;
|
||||
|
||||
virtual bytes extract(const bytes& salt, const bytes& ikm) const = 0;
|
||||
virtual bytes expand(const bytes& prk,
|
||||
const bytes& info,
|
||||
size_t size) const = 0;
|
||||
|
||||
bytes labeled_extract(const bytes& suite_id,
|
||||
const bytes& salt,
|
||||
const bytes& label,
|
||||
const bytes& ikm) const;
|
||||
bytes labeled_expand(const bytes& suite_id,
|
||||
const bytes& prk,
|
||||
const bytes& label,
|
||||
const bytes& info,
|
||||
size_t size) const;
|
||||
|
||||
protected:
|
||||
KDF(ID id_in, size_t hash_size_in);
|
||||
};
|
||||
|
||||
struct AEAD
|
||||
{
|
||||
enum struct ID : uint16_t
|
||||
{
|
||||
AES_128_GCM = 0x0001,
|
||||
AES_256_GCM = 0x0002,
|
||||
CHACHA20_POLY1305 = 0x0003,
|
||||
|
||||
// Reserved identifier for pseudo-AEAD on contexts that only allow export
|
||||
export_only = 0xffff,
|
||||
};
|
||||
|
||||
template<AEAD::ID id>
|
||||
static const AEAD& get();
|
||||
|
||||
virtual ~AEAD() = default;
|
||||
|
||||
const ID id;
|
||||
const size_t key_size;
|
||||
const size_t nonce_size;
|
||||
|
||||
virtual bytes seal(const bytes& key,
|
||||
const bytes& nonce,
|
||||
const bytes& aad,
|
||||
const bytes& pt) const = 0;
|
||||
virtual std::optional<bytes> open(const bytes& key,
|
||||
const bytes& nonce,
|
||||
const bytes& aad,
|
||||
const bytes& ct) const = 0;
|
||||
|
||||
protected:
|
||||
AEAD(ID id_in, size_t key_size_in, size_t nonce_size_in);
|
||||
};
|
||||
|
||||
struct Context
|
||||
{
|
||||
bytes do_export(const bytes& exporter_context, size_t size) const;
|
||||
|
||||
protected:
|
||||
bytes suite;
|
||||
bytes key;
|
||||
bytes nonce;
|
||||
bytes exporter_secret;
|
||||
const KDF& kdf;
|
||||
const AEAD& aead;
|
||||
|
||||
bytes current_nonce() const;
|
||||
void increment_seq();
|
||||
|
||||
private:
|
||||
uint64_t seq;
|
||||
|
||||
Context(bytes suite_in,
|
||||
bytes key_in,
|
||||
bytes nonce_in,
|
||||
bytes exporter_secret_in,
|
||||
const KDF& kdf_in,
|
||||
const AEAD& aead_in);
|
||||
|
||||
friend struct HPKE;
|
||||
friend struct HPKETest;
|
||||
friend bool operator==(const Context& lhs, const Context& rhs);
|
||||
};
|
||||
|
||||
struct SenderContext : public Context
|
||||
{
|
||||
SenderContext(Context&& c);
|
||||
bytes seal(const bytes& aad, const bytes& pt);
|
||||
};
|
||||
|
||||
struct ReceiverContext : public Context
|
||||
{
|
||||
ReceiverContext(Context&& c);
|
||||
std::optional<bytes> open(const bytes& aad, const bytes& ct);
|
||||
};
|
||||
|
||||
struct HPKE
|
||||
{
|
||||
enum struct Mode : uint8_t
|
||||
{
|
||||
base = 0,
|
||||
psk = 1,
|
||||
auth = 2,
|
||||
auth_psk = 3,
|
||||
};
|
||||
|
||||
HPKE(KEM::ID kem_id, KDF::ID kdf_id, AEAD::ID aead_id);
|
||||
|
||||
using SenderInfo = std::pair<bytes, SenderContext>;
|
||||
|
||||
SenderInfo setup_base_s(const KEM::PublicKey& pkR, const bytes& info) const;
|
||||
ReceiverContext setup_base_r(const bytes& enc,
|
||||
const KEM::PrivateKey& skR,
|
||||
const bytes& info) const;
|
||||
|
||||
SenderInfo setup_psk_s(const KEM::PublicKey& pkR,
|
||||
const bytes& info,
|
||||
const bytes& psk,
|
||||
const bytes& psk_id) const;
|
||||
ReceiverContext setup_psk_r(const bytes& enc,
|
||||
const KEM::PrivateKey& skR,
|
||||
const bytes& info,
|
||||
const bytes& psk,
|
||||
const bytes& psk_id) const;
|
||||
|
||||
SenderInfo setup_auth_s(const KEM::PublicKey& pkR,
|
||||
const bytes& info,
|
||||
const KEM::PrivateKey& skS) const;
|
||||
ReceiverContext setup_auth_r(const bytes& enc,
|
||||
const KEM::PrivateKey& skR,
|
||||
const bytes& info,
|
||||
const KEM::PublicKey& pkS) const;
|
||||
|
||||
SenderInfo setup_auth_psk_s(const KEM::PublicKey& pkR,
|
||||
const bytes& info,
|
||||
const bytes& psk,
|
||||
const bytes& psk_id,
|
||||
const KEM::PrivateKey& skS) const;
|
||||
ReceiverContext setup_auth_psk_r(const bytes& enc,
|
||||
const KEM::PrivateKey& skR,
|
||||
const bytes& info,
|
||||
const bytes& psk,
|
||||
const bytes& psk_id,
|
||||
const KEM::PublicKey& pkS) const;
|
||||
|
||||
bytes suite;
|
||||
const KEM& kem;
|
||||
const KDF& kdf;
|
||||
const AEAD& aead;
|
||||
|
||||
private:
|
||||
static bool verify_psk_inputs(Mode mode,
|
||||
const bytes& psk,
|
||||
const bytes& psk_id);
|
||||
Context key_schedule(Mode mode,
|
||||
const bytes& shared_secret,
|
||||
const bytes& info,
|
||||
const bytes& psk,
|
||||
const bytes& psk_id) const;
|
||||
};
|
||||
|
||||
} // namespace mlspp::hpke
|
||||
11
DPP/mlspp/lib/hpke/include/hpke/random.h
Executable file
11
DPP/mlspp/lib/hpke/include/hpke/random.h
Executable file
@@ -0,0 +1,11 @@
|
||||
#pragma once
|
||||
|
||||
#include <bytes/bytes.h>
|
||||
using namespace mlspp::bytes_ns;
|
||||
|
||||
namespace mlspp::hpke {
|
||||
|
||||
bytes
|
||||
random_bytes(size_t size);
|
||||
|
||||
} // namespace mlspp::hpke
|
||||
89
DPP/mlspp/lib/hpke/include/hpke/signature.h
Executable file
89
DPP/mlspp/lib/hpke/include/hpke/signature.h
Executable file
@@ -0,0 +1,89 @@
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include <bytes/bytes.h>
|
||||
using namespace mlspp::bytes_ns;
|
||||
|
||||
namespace mlspp::hpke {
|
||||
|
||||
struct Signature
|
||||
{
|
||||
enum struct ID
|
||||
{
|
||||
P256_SHA256,
|
||||
P384_SHA384,
|
||||
P521_SHA512,
|
||||
Ed25519,
|
||||
#if !defined(WITH_BORINGSSL)
|
||||
Ed448,
|
||||
#endif
|
||||
RSA_SHA256,
|
||||
RSA_SHA384,
|
||||
RSA_SHA512,
|
||||
};
|
||||
|
||||
template<Signature::ID id>
|
||||
static const Signature& get();
|
||||
|
||||
virtual ~Signature() = default;
|
||||
|
||||
struct PublicKey
|
||||
{
|
||||
virtual ~PublicKey() = default;
|
||||
};
|
||||
|
||||
struct PrivateKey
|
||||
{
|
||||
virtual ~PrivateKey() = default;
|
||||
virtual std::unique_ptr<PublicKey> public_key() const = 0;
|
||||
};
|
||||
|
||||
const ID id;
|
||||
|
||||
virtual std::unique_ptr<PrivateKey> generate_key_pair() const = 0;
|
||||
virtual std::unique_ptr<PrivateKey> derive_key_pair(
|
||||
const bytes& ikm) const = 0;
|
||||
|
||||
virtual bytes serialize(const PublicKey& pk) const = 0;
|
||||
virtual std::unique_ptr<PublicKey> deserialize(const bytes& enc) const = 0;
|
||||
|
||||
virtual bytes serialize_private(const PrivateKey& sk) const = 0;
|
||||
virtual std::unique_ptr<PrivateKey> deserialize_private(
|
||||
const bytes& skm) const = 0;
|
||||
|
||||
struct PrivateJWK
|
||||
{
|
||||
const Signature& sig;
|
||||
std::optional<std::string> key_id;
|
||||
std::unique_ptr<PrivateKey> key;
|
||||
};
|
||||
static PrivateJWK parse_jwk_private(const std::string& jwk_json);
|
||||
|
||||
struct PublicJWK
|
||||
{
|
||||
const Signature& sig;
|
||||
std::optional<std::string> key_id;
|
||||
std::unique_ptr<PublicKey> key;
|
||||
};
|
||||
static PublicJWK parse_jwk(const std::string& jwk_json);
|
||||
|
||||
virtual std::unique_ptr<PrivateKey> import_jwk_private(
|
||||
const std::string& jwk_json) const = 0;
|
||||
virtual std::unique_ptr<PublicKey> import_jwk(
|
||||
const std::string& jwk_json) const = 0;
|
||||
virtual std::string export_jwk_private(const PrivateKey& env) const = 0;
|
||||
virtual std::string export_jwk(const PublicKey& env) const = 0;
|
||||
|
||||
virtual bytes sign(const bytes& data, const PrivateKey& sk) const = 0;
|
||||
virtual bool verify(const bytes& data,
|
||||
const bytes& sig,
|
||||
const PublicKey& pk) const = 0;
|
||||
|
||||
static std::unique_ptr<PrivateKey> generate_rsa(size_t bits);
|
||||
|
||||
protected:
|
||||
Signature(ID id_in);
|
||||
};
|
||||
|
||||
} // namespace mlspp::hpke
|
||||
82
DPP/mlspp/lib/hpke/include/hpke/userinfo_vc.h
Executable file
82
DPP/mlspp/lib/hpke/include/hpke/userinfo_vc.h
Executable file
@@ -0,0 +1,82 @@
|
||||
#pragma once
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
|
||||
#include <bytes/bytes.h>
|
||||
#include <chrono>
|
||||
#include <hpke/signature.h>
|
||||
#include <map>
|
||||
|
||||
using namespace mlspp::bytes_ns;
|
||||
|
||||
namespace mlspp::hpke {
|
||||
|
||||
struct UserInfoClaimsAddress
|
||||
{
|
||||
std::optional<std::string> formatted;
|
||||
std::optional<std::string> street_address;
|
||||
std::optional<std::string> locality;
|
||||
std::optional<std::string> region;
|
||||
std::optional<std::string> postal_code;
|
||||
std::optional<std::string> country;
|
||||
};
|
||||
|
||||
struct UserInfoClaims
|
||||
{
|
||||
|
||||
std::optional<std::string> sub;
|
||||
std::optional<std::string> name;
|
||||
std::optional<std::string> given_name;
|
||||
std::optional<std::string> family_name;
|
||||
std::optional<std::string> middle_name;
|
||||
std::optional<std::string> nickname;
|
||||
std::optional<std::string> preferred_username;
|
||||
std::optional<std::string> profile;
|
||||
std::optional<std::string> picture;
|
||||
std::optional<std::string> website;
|
||||
std::optional<std::string> email;
|
||||
std::optional<bool> email_verified;
|
||||
std::optional<std::string> gender;
|
||||
std::optional<std::string> birthdate;
|
||||
std::optional<std::string> zoneinfo;
|
||||
std::optional<std::string> locale;
|
||||
std::optional<std::string> phone_number;
|
||||
std::optional<bool> phone_number_verified;
|
||||
std::optional<UserInfoClaimsAddress> address;
|
||||
std::optional<uint64_t> updated_at;
|
||||
|
||||
static UserInfoClaims from_json(const std::string& cred_subject);
|
||||
};
|
||||
|
||||
struct UserInfoVC
|
||||
{
|
||||
private:
|
||||
struct ParsedCredential;
|
||||
std::shared_ptr<ParsedCredential> parsed_cred;
|
||||
|
||||
public:
|
||||
explicit UserInfoVC(std::string jwt);
|
||||
UserInfoVC() = default;
|
||||
UserInfoVC(const UserInfoVC& other) = default;
|
||||
~UserInfoVC() = default;
|
||||
UserInfoVC& operator=(const UserInfoVC& other) = default;
|
||||
UserInfoVC& operator=(UserInfoVC&& other) = default;
|
||||
|
||||
const Signature& signature_algorithm() const;
|
||||
std::string issuer() const;
|
||||
std::optional<std::string> key_id() const;
|
||||
std::chrono::system_clock::time_point not_before() const;
|
||||
std::chrono::system_clock::time_point not_after() const;
|
||||
const std::string& raw_credential() const;
|
||||
const UserInfoClaims& subject() const;
|
||||
const Signature::PublicJWK& public_key() const;
|
||||
|
||||
bool valid_from(const Signature::PublicKey& issuer_key) const;
|
||||
|
||||
std::string raw;
|
||||
};
|
||||
|
||||
bool
|
||||
operator==(const UserInfoVC& lhs, const UserInfoVC& rhs);
|
||||
|
||||
} // namespace mlspp::hpke
|
||||
321
DPP/mlspp/lib/hpke/src/aead_cipher.cpp
Executable file
321
DPP/mlspp/lib/hpke/src/aead_cipher.cpp
Executable file
@@ -0,0 +1,321 @@
|
||||
#include "aead_cipher.h"
|
||||
#include "openssl_common.h"
|
||||
|
||||
#include <openssl/evp.h>
|
||||
|
||||
#if WITH_BORINGSSL
|
||||
#include <openssl/aead.h>
|
||||
#endif
|
||||
|
||||
namespace mlspp::hpke {
|
||||
|
||||
///
|
||||
/// ExportOnlyCipher
|
||||
///
|
||||
bytes
|
||||
ExportOnlyCipher::seal(const bytes& /* key */,
|
||||
const bytes& /* nonce */,
|
||||
const bytes& /* aad */,
|
||||
const bytes& /* pt */) const
|
||||
{
|
||||
throw std::runtime_error("seal() on export-only context");
|
||||
}
|
||||
|
||||
std::optional<bytes>
|
||||
ExportOnlyCipher::open(const bytes& /* key */,
|
||||
const bytes& /* nonce */,
|
||||
const bytes& /* aad */,
|
||||
const bytes& /* ct */) const
|
||||
{
|
||||
throw std::runtime_error("open() on export-only context");
|
||||
}
|
||||
|
||||
ExportOnlyCipher::ExportOnlyCipher()
|
||||
: AEAD(AEAD::ID::export_only, 0, 0)
|
||||
{
|
||||
}
|
||||
|
||||
///
|
||||
/// AEADCipher
|
||||
///
|
||||
AEADCipher
|
||||
make_aead(AEAD::ID cipher_in)
|
||||
{
|
||||
return { cipher_in };
|
||||
}
|
||||
|
||||
template<>
|
||||
const AEADCipher&
|
||||
AEADCipher::get<AEAD::ID::AES_128_GCM>()
|
||||
{
|
||||
static const auto instance = make_aead(AEAD::ID::AES_128_GCM);
|
||||
return instance;
|
||||
}
|
||||
|
||||
template<>
|
||||
const AEADCipher&
|
||||
AEADCipher::get<AEAD::ID::AES_256_GCM>()
|
||||
{
|
||||
static const auto instance = make_aead(AEAD::ID::AES_256_GCM);
|
||||
return instance;
|
||||
}
|
||||
|
||||
template<>
|
||||
const AEADCipher&
|
||||
AEADCipher::get<AEAD::ID::CHACHA20_POLY1305>()
|
||||
{
|
||||
static const auto instance = make_aead(AEAD::ID::CHACHA20_POLY1305);
|
||||
return instance;
|
||||
}
|
||||
|
||||
static size_t
|
||||
cipher_key_size(AEAD::ID cipher)
|
||||
{
|
||||
switch (cipher) {
|
||||
case AEAD::ID::AES_128_GCM:
|
||||
return 16;
|
||||
|
||||
case AEAD::ID::AES_256_GCM:
|
||||
case AEAD::ID::CHACHA20_POLY1305:
|
||||
return 32;
|
||||
|
||||
default:
|
||||
throw std::runtime_error("Unsupported algorithm");
|
||||
}
|
||||
}
|
||||
|
||||
static size_t
|
||||
cipher_nonce_size(AEAD::ID cipher)
|
||||
{
|
||||
switch (cipher) {
|
||||
case AEAD::ID::AES_128_GCM:
|
||||
case AEAD::ID::AES_256_GCM:
|
||||
case AEAD::ID::CHACHA20_POLY1305:
|
||||
return 12;
|
||||
|
||||
default:
|
||||
throw std::runtime_error("Unsupported algorithm");
|
||||
}
|
||||
}
|
||||
|
||||
static size_t
|
||||
cipher_tag_size(AEAD::ID cipher)
|
||||
{
|
||||
switch (cipher) {
|
||||
case AEAD::ID::AES_128_GCM:
|
||||
case AEAD::ID::AES_256_GCM:
|
||||
case AEAD::ID::CHACHA20_POLY1305:
|
||||
return 16;
|
||||
|
||||
default:
|
||||
throw std::runtime_error("Unsupported algorithm");
|
||||
}
|
||||
}
|
||||
|
||||
#if WITH_BORINGSSL
|
||||
static const EVP_AEAD*
|
||||
boringssl_cipher(AEAD::ID cipher)
|
||||
{
|
||||
switch (cipher) {
|
||||
case AEAD::ID::AES_128_GCM:
|
||||
return EVP_aead_aes_128_gcm();
|
||||
|
||||
case AEAD::ID::AES_256_GCM:
|
||||
return EVP_aead_aes_256_gcm();
|
||||
|
||||
case AEAD::ID::CHACHA20_POLY1305:
|
||||
return EVP_aead_chacha20_poly1305();
|
||||
|
||||
default:
|
||||
throw std::runtime_error("Unsupported algorithm");
|
||||
}
|
||||
}
|
||||
#else
|
||||
static const EVP_CIPHER*
|
||||
openssl_cipher(AEAD::ID cipher)
|
||||
{
|
||||
switch (cipher) {
|
||||
case AEAD::ID::AES_128_GCM:
|
||||
return EVP_aes_128_gcm();
|
||||
|
||||
case AEAD::ID::AES_256_GCM:
|
||||
return EVP_aes_256_gcm();
|
||||
|
||||
case AEAD::ID::CHACHA20_POLY1305:
|
||||
return EVP_chacha20_poly1305();
|
||||
|
||||
default:
|
||||
throw std::runtime_error("Unsupported algorithm");
|
||||
}
|
||||
}
|
||||
#endif // WITH_BORINGSSL
|
||||
|
||||
AEADCipher::AEADCipher(AEAD::ID id_in)
|
||||
: AEAD(id_in, cipher_key_size(id_in), cipher_nonce_size(id_in))
|
||||
, tag_size(cipher_tag_size(id))
|
||||
{
|
||||
}
|
||||
|
||||
bytes
|
||||
AEADCipher::seal(const bytes& key,
|
||||
const bytes& nonce,
|
||||
const bytes& aad,
|
||||
const bytes& pt) const
|
||||
{
|
||||
#if WITH_BORINGSSL
|
||||
auto ctx = make_typed_unique(
|
||||
EVP_AEAD_CTX_new(boringssl_cipher(id), key.data(), key.size(), tag_size));
|
||||
if (ctx == nullptr) {
|
||||
throw openssl_error();
|
||||
}
|
||||
|
||||
auto ct = bytes(pt.size() + tag_size);
|
||||
auto out_len = ct.size();
|
||||
if (1 != EVP_AEAD_CTX_seal(ctx.get(),
|
||||
ct.data(),
|
||||
&out_len,
|
||||
ct.size(),
|
||||
nonce.data(),
|
||||
nonce.size(),
|
||||
pt.data(),
|
||||
pt.size(),
|
||||
aad.data(),
|
||||
aad.size())) {
|
||||
throw openssl_error();
|
||||
}
|
||||
|
||||
return ct;
|
||||
#else
|
||||
auto ctx = make_typed_unique(EVP_CIPHER_CTX_new());
|
||||
if (ctx == nullptr) {
|
||||
throw openssl_error();
|
||||
}
|
||||
|
||||
const auto* cipher = openssl_cipher(id);
|
||||
if (1 != EVP_EncryptInit(ctx.get(), cipher, key.data(), nonce.data())) {
|
||||
throw openssl_error();
|
||||
}
|
||||
|
||||
int outlen = 0;
|
||||
if (!aad.empty()) {
|
||||
if (1 != EVP_EncryptUpdate(ctx.get(),
|
||||
nullptr,
|
||||
&outlen,
|
||||
aad.data(),
|
||||
static_cast<int>(aad.size()))) {
|
||||
throw openssl_error();
|
||||
}
|
||||
}
|
||||
|
||||
bytes ct(pt.size());
|
||||
if (1 != EVP_EncryptUpdate(ctx.get(),
|
||||
ct.data(),
|
||||
&outlen,
|
||||
pt.data(),
|
||||
static_cast<int>(pt.size()))) {
|
||||
throw openssl_error();
|
||||
}
|
||||
|
||||
// Providing nullptr as an argument is safe here because this
|
||||
// function never writes with GCM; it only computes the tag
|
||||
if (1 != EVP_EncryptFinal(ctx.get(), nullptr, &outlen)) {
|
||||
throw openssl_error();
|
||||
}
|
||||
|
||||
bytes tag(tag_size);
|
||||
if (1 != EVP_CIPHER_CTX_ctrl(ctx.get(),
|
||||
EVP_CTRL_GCM_GET_TAG,
|
||||
static_cast<int>(tag_size),
|
||||
tag.data())) {
|
||||
throw openssl_error();
|
||||
}
|
||||
|
||||
ct += tag;
|
||||
return ct;
|
||||
#endif // WITH_BORINGSSL
|
||||
}
|
||||
|
||||
std::optional<bytes>
|
||||
AEADCipher::open(const bytes& key,
|
||||
const bytes& nonce,
|
||||
const bytes& aad,
|
||||
const bytes& ct) const
|
||||
{
|
||||
if (ct.size() < tag_size) {
|
||||
throw std::runtime_error("AEAD ciphertext smaller than tag size");
|
||||
}
|
||||
|
||||
#if WITH_BORINGSSL
|
||||
auto ctx = make_typed_unique(EVP_AEAD_CTX_new(
|
||||
boringssl_cipher(id), key.data(), key.size(), cipher_tag_size(id)));
|
||||
if (ctx == nullptr) {
|
||||
throw openssl_error();
|
||||
}
|
||||
|
||||
auto pt = bytes(ct.size() - tag_size);
|
||||
auto out_len = pt.size();
|
||||
if (1 != EVP_AEAD_CTX_open(ctx.get(),
|
||||
pt.data(),
|
||||
&out_len,
|
||||
pt.size(),
|
||||
nonce.data(),
|
||||
nonce.size(),
|
||||
ct.data(),
|
||||
ct.size(),
|
||||
aad.data(),
|
||||
aad.size())) {
|
||||
throw openssl_error();
|
||||
}
|
||||
|
||||
return pt;
|
||||
#else
|
||||
auto ctx = make_typed_unique(EVP_CIPHER_CTX_new());
|
||||
if (ctx == nullptr) {
|
||||
throw openssl_error();
|
||||
}
|
||||
|
||||
const auto* cipher = openssl_cipher(id);
|
||||
if (1 != EVP_DecryptInit(ctx.get(), cipher, key.data(), nonce.data())) {
|
||||
throw openssl_error();
|
||||
}
|
||||
|
||||
auto inner_ct_size = ct.size() - tag_size;
|
||||
auto tag = ct.slice(inner_ct_size, ct.size());
|
||||
if (1 != EVP_CIPHER_CTX_ctrl(ctx.get(),
|
||||
EVP_CTRL_GCM_SET_TAG,
|
||||
static_cast<int>(tag_size),
|
||||
tag.data())) {
|
||||
throw openssl_error();
|
||||
}
|
||||
|
||||
int out_size = 0;
|
||||
if (!aad.empty()) {
|
||||
if (1 != EVP_DecryptUpdate(ctx.get(),
|
||||
nullptr,
|
||||
&out_size,
|
||||
aad.data(),
|
||||
static_cast<int>(aad.size()))) {
|
||||
throw openssl_error();
|
||||
}
|
||||
}
|
||||
|
||||
bytes pt(inner_ct_size);
|
||||
if (1 != EVP_DecryptUpdate(ctx.get(),
|
||||
pt.data(),
|
||||
&out_size,
|
||||
ct.data(),
|
||||
static_cast<int>(inner_ct_size))) {
|
||||
throw openssl_error();
|
||||
}
|
||||
|
||||
// Providing nullptr as an argument is safe here because this
|
||||
// function never writes with GCM; it only verifies the tag
|
||||
if (1 != EVP_DecryptFinal(ctx.get(), nullptr, &out_size)) {
|
||||
throw std::runtime_error("AEAD authentication failure");
|
||||
}
|
||||
|
||||
return pt;
|
||||
#endif // WITH_BORINGSSL
|
||||
}
|
||||
|
||||
} // namespace mlspp::hpke
|
||||
45
DPP/mlspp/lib/hpke/src/aead_cipher.h
Executable file
45
DPP/mlspp/lib/hpke/src/aead_cipher.h
Executable file
@@ -0,0 +1,45 @@
|
||||
#pragma once
|
||||
|
||||
#include <hpke/hpke.h>
|
||||
|
||||
namespace mlspp::hpke {
|
||||
|
||||
struct ExportOnlyCipher : public AEAD
|
||||
{
|
||||
ExportOnlyCipher();
|
||||
~ExportOnlyCipher() override = default;
|
||||
|
||||
bytes seal(const bytes& key,
|
||||
const bytes& nonce,
|
||||
const bytes& aad,
|
||||
const bytes& pt) const override;
|
||||
std::optional<bytes> open(const bytes& key,
|
||||
const bytes& nonce,
|
||||
const bytes& aad,
|
||||
const bytes& ct) const override;
|
||||
};
|
||||
|
||||
struct AEADCipher : public AEAD
|
||||
{
|
||||
template<AEAD::ID id>
|
||||
static const AEADCipher& get();
|
||||
|
||||
~AEADCipher() override = default;
|
||||
|
||||
bytes seal(const bytes& key,
|
||||
const bytes& nonce,
|
||||
const bytes& aad,
|
||||
const bytes& pt) const override;
|
||||
std::optional<bytes> open(const bytes& key,
|
||||
const bytes& nonce,
|
||||
const bytes& aad,
|
||||
const bytes& ct) const override;
|
||||
|
||||
private:
|
||||
const size_t tag_size;
|
||||
|
||||
AEADCipher(AEAD::ID id_in);
|
||||
friend AEADCipher make_aead(AEAD::ID cipher_in);
|
||||
};
|
||||
|
||||
} // namespace mlspp::hpke
|
||||
105
DPP/mlspp/lib/hpke/src/base64.cpp
Executable file
105
DPP/mlspp/lib/hpke/src/base64.cpp
Executable file
@@ -0,0 +1,105 @@
|
||||
#include <hpke/base64.h>
|
||||
|
||||
#include "openssl_common.h"
|
||||
|
||||
#include <openssl/bio.h>
|
||||
#include <openssl/err.h>
|
||||
#include <openssl/evp.h>
|
||||
|
||||
namespace mlspp::hpke {
|
||||
|
||||
std::string
|
||||
to_base64(const bytes& data)
|
||||
{
|
||||
if (data.empty()) {
|
||||
return "";
|
||||
}
|
||||
|
||||
#if WITH_BORINGSSL
|
||||
const auto data_size = data.size();
|
||||
#else
|
||||
const auto data_size = static_cast<int>(data.size());
|
||||
#endif
|
||||
|
||||
// base64 encoding produces 4 characters for every 3 input bytes (rounded up)
|
||||
const auto out_size = (data_size + 2) / 3 * 4;
|
||||
auto out = bytes(out_size + 1); // NUL terminator
|
||||
|
||||
const auto result = EVP_EncodeBlock(out.data(), data.data(), data_size);
|
||||
if (result != out_size) {
|
||||
throw openssl_error();
|
||||
}
|
||||
|
||||
out.resize(out.size() - 1); // strip NUL terminator
|
||||
return to_ascii(out);
|
||||
}
|
||||
|
||||
std::string
|
||||
to_base64url(const bytes& data)
|
||||
{
|
||||
if (data.empty()) {
|
||||
return "";
|
||||
}
|
||||
|
||||
auto encoded = to_base64(data);
|
||||
|
||||
auto pad_start = encoded.find_first_of('=');
|
||||
if (pad_start != std::string::npos) {
|
||||
encoded = encoded.substr(0, pad_start);
|
||||
}
|
||||
|
||||
std::replace(encoded.begin(), encoded.end(), '+', '-');
|
||||
std::replace(encoded.begin(), encoded.end(), '/', '_');
|
||||
|
||||
return encoded;
|
||||
}
|
||||
|
||||
bytes
|
||||
from_base64(const std::string& enc)
|
||||
{
|
||||
if (enc.length() == 0) {
|
||||
return {};
|
||||
}
|
||||
|
||||
if (enc.length() % 4 != 0) {
|
||||
throw std::runtime_error("Base64 length is not divisible by 4");
|
||||
}
|
||||
|
||||
const auto in = from_ascii(enc);
|
||||
const auto in_size = static_cast<int>(in.size());
|
||||
const auto out_size = in_size / 4 * 3;
|
||||
auto out = bytes(out_size);
|
||||
|
||||
const auto result = EVP_DecodeBlock(out.data(), in.data(), in_size);
|
||||
if (result != out_size) {
|
||||
throw openssl_error();
|
||||
}
|
||||
|
||||
if (enc.substr(enc.length() - 2, enc.length()) == "==") {
|
||||
out.resize(out.size() - 2);
|
||||
} else if (enc.substr(enc.length() - 1, enc.length()) == "=") {
|
||||
out.resize(out.size() - 1);
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
bytes
|
||||
from_base64url(const std::string& enc)
|
||||
{
|
||||
if (enc.empty()) {
|
||||
return {};
|
||||
}
|
||||
|
||||
auto enc_copy = enc;
|
||||
std::replace(enc_copy.begin(), enc_copy.end(), '-', '+');
|
||||
std::replace(enc_copy.begin(), enc_copy.end(), '_', '/');
|
||||
|
||||
while (enc_copy.length() % 4 != 0) {
|
||||
enc_copy += "=";
|
||||
}
|
||||
|
||||
return from_base64(enc_copy);
|
||||
}
|
||||
|
||||
} // namespace mlspp::hpke
|
||||
539
DPP/mlspp/lib/hpke/src/certificate.cpp
Executable file
539
DPP/mlspp/lib/hpke/src/certificate.cpp
Executable file
@@ -0,0 +1,539 @@
|
||||
#include "group.h"
|
||||
#include "openssl_common.h"
|
||||
#include "rsa.h"
|
||||
#include <hpke/certificate.h>
|
||||
#include <hpke/signature.h>
|
||||
#include <memory>
|
||||
#include <openssl/err.h>
|
||||
#include <openssl/pem.h>
|
||||
#include <openssl/x509.h>
|
||||
#include <openssl/x509v3.h>
|
||||
#include <tls/compat.h>
|
||||
|
||||
namespace mlspp::hpke {
|
||||
///
|
||||
/// Utility functions
|
||||
///
|
||||
|
||||
static std::optional<bytes>
|
||||
asn1_octet_string_to_bytes(const ASN1_OCTET_STRING* octets)
|
||||
{
|
||||
if (octets == nullptr) {
|
||||
return std::nullopt;
|
||||
}
|
||||
const auto* ptr = ASN1_STRING_get0_data(octets);
|
||||
const auto len = ASN1_STRING_length(octets);
|
||||
// NOLINTNEXTLINE (cppcoreguidelines-pro-bounds-pointer-arithmetic)
|
||||
return std::vector<uint8_t>(ptr, ptr + len);
|
||||
}
|
||||
|
||||
static std::string
|
||||
asn1_string_to_std_string(const ASN1_STRING* asn1_string)
|
||||
{
|
||||
const auto* data = ASN1_STRING_get0_data(asn1_string);
|
||||
const auto data_size = static_cast<size_t>(ASN1_STRING_length(asn1_string));
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
|
||||
auto str = std::string(reinterpret_cast<const char*>(data));
|
||||
if (str.size() != data_size) {
|
||||
throw std::runtime_error("Malformed ASN.1 string");
|
||||
}
|
||||
return str;
|
||||
}
|
||||
|
||||
static std::chrono::system_clock::time_point
|
||||
asn1_time_to_chrono(const ASN1_TIME* asn1_time)
|
||||
{
|
||||
auto epoch_chrono = std::chrono::system_clock::time_point();
|
||||
auto epoch_time_t = std::chrono::system_clock::to_time_t(epoch_chrono);
|
||||
auto epoch_asn1 = make_typed_unique(ASN1_TIME_set(nullptr, epoch_time_t));
|
||||
if (!epoch_asn1) {
|
||||
throw openssl_error();
|
||||
}
|
||||
|
||||
auto secs = int(0);
|
||||
auto days = int(0);
|
||||
if (ASN1_TIME_diff(&days, &secs, epoch_asn1.get(), asn1_time) != 1) {
|
||||
throw openssl_error();
|
||||
}
|
||||
|
||||
auto delta = std::chrono::seconds(secs) + std::chrono::hours(24 * days);
|
||||
return std::chrono::system_clock::time_point(delta);
|
||||
}
|
||||
|
||||
///
|
||||
/// ParsedCertificate
|
||||
///
|
||||
|
||||
const int Certificate::NameType::organization = NID_organizationName;
|
||||
const int Certificate::NameType::common_name = NID_commonName;
|
||||
const int Certificate::NameType::organizational_unit =
|
||||
NID_organizationalUnitName;
|
||||
const int Certificate::NameType::country = NID_countryName;
|
||||
const int Certificate::NameType::serial_number = NID_serialNumber;
|
||||
const int Certificate::NameType::state_or_province_name =
|
||||
NID_stateOrProvinceName;
|
||||
|
||||
struct RFC822Name
|
||||
{
|
||||
std::string value;
|
||||
};
|
||||
|
||||
struct DNSName
|
||||
{
|
||||
std::string value;
|
||||
};
|
||||
|
||||
using GeneralName = tls::var::variant<RFC822Name, DNSName>;
|
||||
|
||||
struct Certificate::ParsedCertificate
|
||||
{
|
||||
|
||||
static std::unique_ptr<ParsedCertificate> parse(const bytes& der)
|
||||
{
|
||||
const auto* buf = der.data();
|
||||
auto cert =
|
||||
make_typed_unique(d2i_X509(nullptr, &buf, static_cast<int>(der.size())));
|
||||
if (cert == nullptr) {
|
||||
throw openssl_error();
|
||||
}
|
||||
|
||||
return std::make_unique<ParsedCertificate>(cert.release());
|
||||
}
|
||||
|
||||
static bytes compute_digest(const X509* cert)
|
||||
{
|
||||
const auto* md = EVP_sha256();
|
||||
auto digest = bytes(EVP_MD_size(md));
|
||||
unsigned int out_size = 0;
|
||||
if (1 != X509_digest(cert, md, digest.data(), &out_size)) {
|
||||
throw openssl_error();
|
||||
}
|
||||
return digest;
|
||||
}
|
||||
|
||||
// Note: This method does not implement total general name parsing.
|
||||
// Duplicate entries are not supported; if they are present, the last one
|
||||
// presented by OpenSSL is chosen.
|
||||
static ParsedName parse_names(const X509_NAME* x509_name)
|
||||
{
|
||||
if (x509_name == nullptr) {
|
||||
throw openssl_error();
|
||||
}
|
||||
|
||||
ParsedName parsed_name;
|
||||
|
||||
for (int i = X509_NAME_entry_count(x509_name) - 1; i >= 0; i--) {
|
||||
auto* entry = X509_NAME_get_entry(x509_name, i);
|
||||
if (entry == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto* oid = X509_NAME_ENTRY_get_object(entry);
|
||||
auto* asn_str = X509_NAME_ENTRY_get_data(entry);
|
||||
if (oid == nullptr || asn_str == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const int nid = OBJ_obj2nid(oid);
|
||||
const std::string parsed_value = asn1_string_to_std_string(asn_str);
|
||||
parsed_name[nid] = parsed_value;
|
||||
}
|
||||
|
||||
return parsed_name;
|
||||
}
|
||||
|
||||
// Parse Subject Key Identifier Extension
|
||||
static std::optional<bytes> parse_skid(X509* cert)
|
||||
{
|
||||
return asn1_octet_string_to_bytes(X509_get0_subject_key_id(cert));
|
||||
}
|
||||
|
||||
// Parse Authority Key Identifier
|
||||
static std::optional<bytes> parse_akid(X509* cert)
|
||||
{
|
||||
return asn1_octet_string_to_bytes(X509_get0_authority_key_id(cert));
|
||||
}
|
||||
|
||||
static std::vector<GeneralName> parse_san(X509* cert)
|
||||
{
|
||||
std::vector<GeneralName> names;
|
||||
|
||||
#ifdef WITH_BORINGSSL
|
||||
using san_names_nb_t = size_t;
|
||||
#else
|
||||
using san_names_nb_t = int;
|
||||
#endif
|
||||
|
||||
san_names_nb_t san_names_nb = 0;
|
||||
|
||||
auto* ext_ptr =
|
||||
X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr);
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
|
||||
auto* san_ptr = reinterpret_cast<STACK_OF(GENERAL_NAME)*>(ext_ptr);
|
||||
const auto san_names = make_typed_unique(san_ptr);
|
||||
san_names_nb = sk_GENERAL_NAME_num(san_names.get());
|
||||
|
||||
// Check each name within the extension
|
||||
for (san_names_nb_t i = 0; i < san_names_nb; i++) {
|
||||
auto* current_name = sk_GENERAL_NAME_value(san_names.get(), i);
|
||||
if (current_name->type == GEN_DNS) {
|
||||
const auto dns_name =
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-union-access)
|
||||
asn1_string_to_std_string(current_name->d.dNSName);
|
||||
names.emplace_back(DNSName{ dns_name });
|
||||
} else if (current_name->type == GEN_EMAIL) {
|
||||
const auto email =
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-union-access
|
||||
asn1_string_to_std_string(current_name->d.rfc822Name);
|
||||
names.emplace_back(RFC822Name{ email });
|
||||
}
|
||||
}
|
||||
|
||||
return names;
|
||||
}
|
||||
|
||||
explicit ParsedCertificate(X509* x509_in)
|
||||
: x509(x509_in, typed_delete<X509>)
|
||||
, pub_key_id(public_key_algorithm(x509.get()))
|
||||
, sig_algo(signature_algorithm(x509.get()))
|
||||
, issuer_hash(X509_issuer_name_hash(x509.get()))
|
||||
, subject_hash(X509_subject_name_hash(x509.get()))
|
||||
, issuer(parse_names(X509_get_issuer_name(x509.get())))
|
||||
, subject(parse_names(X509_get_subject_name(x509.get())))
|
||||
, subject_key_id(parse_skid(x509.get()))
|
||||
, authority_key_id(parse_akid(x509.get()))
|
||||
, sub_alt_names(parse_san(x509.get()))
|
||||
, is_ca(X509_check_ca(x509.get()) != 0)
|
||||
, hash(compute_digest(x509.get()))
|
||||
, not_before(asn1_time_to_chrono(X509_get0_notBefore(x509.get())))
|
||||
, not_after(asn1_time_to_chrono(X509_get0_notAfter(x509.get())))
|
||||
{
|
||||
}
|
||||
|
||||
ParsedCertificate(const ParsedCertificate& other)
|
||||
: x509(nullptr, typed_delete<X509>)
|
||||
, pub_key_id(public_key_algorithm(other.x509.get()))
|
||||
, sig_algo(signature_algorithm(other.x509.get()))
|
||||
, issuer_hash(other.issuer_hash)
|
||||
, subject_hash(other.subject_hash)
|
||||
, issuer(other.issuer)
|
||||
, subject(other.subject)
|
||||
, subject_key_id(other.subject_key_id)
|
||||
, authority_key_id(other.authority_key_id)
|
||||
, sub_alt_names(other.sub_alt_names)
|
||||
, is_ca(other.is_ca)
|
||||
, hash(other.hash)
|
||||
, not_before(other.not_before)
|
||||
, not_after(other.not_after)
|
||||
{
|
||||
if (1 != X509_up_ref(other.x509.get())) {
|
||||
throw openssl_error();
|
||||
}
|
||||
x509.reset(other.x509.get());
|
||||
}
|
||||
|
||||
static Signature::ID public_key_algorithm(X509* x509)
|
||||
{
|
||||
#if WITH_BORINGSSL
|
||||
const auto pub = make_typed_unique(X509_get_pubkey(x509));
|
||||
const auto* pub_ptr = pub.get();
|
||||
#else
|
||||
const auto* pub_ptr = X509_get0_pubkey(x509);
|
||||
#endif
|
||||
|
||||
switch (EVP_PKEY_base_id(pub_ptr)) {
|
||||
case EVP_PKEY_ED25519:
|
||||
return Signature::ID::Ed25519;
|
||||
#if !defined(WITH_BORINGSSL)
|
||||
case EVP_PKEY_ED448:
|
||||
return Signature::ID::Ed448;
|
||||
#endif
|
||||
case EVP_PKEY_EC: {
|
||||
auto key_size = EVP_PKEY_bits(pub_ptr);
|
||||
switch (key_size) {
|
||||
case 256:
|
||||
return Signature::ID::P256_SHA256;
|
||||
case 384:
|
||||
return Signature::ID::P384_SHA384;
|
||||
case 521:
|
||||
return Signature::ID::P521_SHA512;
|
||||
default:
|
||||
throw std::runtime_error("Unknown curve");
|
||||
}
|
||||
}
|
||||
case EVP_PKEY_RSA:
|
||||
// RSA public keys are not specific to an algorithm
|
||||
return Signature::ID::RSA_SHA256;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
throw std::runtime_error("Unsupported public key algorithm");
|
||||
}
|
||||
|
||||
static Signature::ID signature_algorithm(X509* cert)
|
||||
{
|
||||
auto nid = X509_get_signature_nid(cert);
|
||||
switch (nid) {
|
||||
case EVP_PKEY_ED25519:
|
||||
return Signature::ID::Ed25519;
|
||||
#if !defined(WITH_BORINGSSL)
|
||||
case EVP_PKEY_ED448:
|
||||
return Signature::ID::Ed448;
|
||||
#endif
|
||||
case NID_ecdsa_with_SHA256:
|
||||
return Signature::ID::P256_SHA256;
|
||||
case NID_ecdsa_with_SHA384:
|
||||
return Signature::ID::P384_SHA384;
|
||||
case NID_ecdsa_with_SHA512:
|
||||
return Signature::ID::P521_SHA512;
|
||||
case NID_sha1WithRSAEncryption:
|
||||
// We fall through to SHA256 for SHA1 because we do not implement SHA-1.
|
||||
case NID_sha256WithRSAEncryption:
|
||||
return Signature::ID::RSA_SHA256;
|
||||
case NID_sha384WithRSAEncryption:
|
||||
return Signature::ID::RSA_SHA384;
|
||||
case NID_sha512WithRSAEncryption:
|
||||
return Signature::ID::RSA_SHA512;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
throw std::runtime_error("Unsupported signature algorithm");
|
||||
}
|
||||
|
||||
typed_unique_ptr<EVP_PKEY> public_key() const
|
||||
{
|
||||
return make_typed_unique<EVP_PKEY>(X509_get_pubkey(x509.get()));
|
||||
}
|
||||
|
||||
Certificate::ExpirationStatus expiration_status() const
|
||||
{
|
||||
auto now = std::chrono::system_clock::now();
|
||||
|
||||
if (now < not_before) {
|
||||
return Certificate::ExpirationStatus::inactive;
|
||||
}
|
||||
|
||||
if (now > not_after) {
|
||||
return Certificate::ExpirationStatus::expired;
|
||||
}
|
||||
|
||||
return Certificate::ExpirationStatus::active;
|
||||
}
|
||||
|
||||
bytes raw() const
|
||||
{
|
||||
auto out = bytes(i2d_X509(x509.get(), nullptr));
|
||||
auto* ptr = out.data();
|
||||
i2d_X509(x509.get(), &ptr);
|
||||
return out;
|
||||
}
|
||||
|
||||
typed_unique_ptr<X509> x509;
|
||||
const Signature::ID pub_key_id;
|
||||
const Signature::ID sig_algo;
|
||||
const uint64_t issuer_hash;
|
||||
const uint64_t subject_hash;
|
||||
const ParsedName issuer;
|
||||
const ParsedName subject;
|
||||
const std::optional<bytes> subject_key_id;
|
||||
const std::optional<bytes> authority_key_id;
|
||||
const std::vector<GeneralName> sub_alt_names;
|
||||
const bool is_ca;
|
||||
const bytes hash;
|
||||
const std::chrono::system_clock::time_point not_before;
|
||||
const std::chrono::system_clock::time_point not_after;
|
||||
};
|
||||
|
||||
///
|
||||
/// Certificate
|
||||
///
|
||||
|
||||
static std::unique_ptr<Signature::PublicKey>
|
||||
signature_key(EVP_PKEY* pkey)
|
||||
{
|
||||
switch (EVP_PKEY_base_id(pkey)) {
|
||||
case EVP_PKEY_RSA:
|
||||
return std::make_unique<RSASignature::PublicKey>(pkey);
|
||||
|
||||
case EVP_PKEY_ED448:
|
||||
case EVP_PKEY_ED25519:
|
||||
case EVP_PKEY_EC:
|
||||
return std::make_unique<EVPGroup::PublicKey>(pkey);
|
||||
|
||||
default:
|
||||
throw std::runtime_error("Unsupported algorithm");
|
||||
}
|
||||
}
|
||||
|
||||
Certificate::Certificate(std::unique_ptr<ParsedCertificate>&& parsed_cert_in)
|
||||
: parsed_cert(std::move(parsed_cert_in))
|
||||
, public_key(signature_key(parsed_cert->public_key().release()))
|
||||
, raw(parsed_cert->raw())
|
||||
{
|
||||
}
|
||||
|
||||
Certificate::Certificate(const bytes& der)
|
||||
: parsed_cert(ParsedCertificate::parse(der))
|
||||
, public_key(signature_key(parsed_cert->public_key().release()))
|
||||
, raw(der)
|
||||
{
|
||||
}
|
||||
|
||||
Certificate::Certificate(const Certificate& other)
|
||||
: parsed_cert(std::make_unique<ParsedCertificate>(*other.parsed_cert))
|
||||
, public_key(signature_key(parsed_cert->public_key().release()))
|
||||
, raw(other.raw)
|
||||
{
|
||||
}
|
||||
|
||||
Certificate::~Certificate() = default;
|
||||
|
||||
std::vector<Certificate>
|
||||
Certificate::parse_pem(const bytes& pem)
|
||||
{
|
||||
auto size_int = static_cast<int>(pem.size());
|
||||
auto bio = make_typed_unique<BIO>(BIO_new_mem_buf(pem.data(), size_int));
|
||||
if (!bio) {
|
||||
throw openssl_error();
|
||||
}
|
||||
|
||||
auto certs = std::vector<Certificate>();
|
||||
while (true) {
|
||||
auto x509 = make_typed_unique<X509>(
|
||||
PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
|
||||
if (!x509) {
|
||||
// NOLINTNEXTLINE(hicpp-signed-bitwise)
|
||||
auto err = ERR_GET_REASON(ERR_peek_last_error());
|
||||
if (err == PEM_R_NO_START_LINE) {
|
||||
// No more objects to read
|
||||
break;
|
||||
}
|
||||
|
||||
throw openssl_error();
|
||||
}
|
||||
|
||||
auto parsed = std::make_unique<ParsedCertificate>(x509.release());
|
||||
certs.emplace_back(std::move(parsed));
|
||||
}
|
||||
|
||||
return certs;
|
||||
}
|
||||
|
||||
bool
|
||||
Certificate::valid_from(const Certificate& parent) const
|
||||
{
|
||||
auto pub = parent.parsed_cert->public_key();
|
||||
return (1 == X509_verify(parsed_cert->x509.get(), pub.get()));
|
||||
}
|
||||
|
||||
uint64_t
|
||||
Certificate::issuer_hash() const
|
||||
{
|
||||
return parsed_cert->issuer_hash;
|
||||
}
|
||||
|
||||
uint64_t
|
||||
Certificate::subject_hash() const
|
||||
{
|
||||
return parsed_cert->subject_hash;
|
||||
}
|
||||
|
||||
Certificate::ParsedName
|
||||
Certificate::subject() const
|
||||
{
|
||||
return parsed_cert->subject;
|
||||
}
|
||||
|
||||
Certificate::ParsedName
|
||||
Certificate::issuer() const
|
||||
{
|
||||
return parsed_cert->issuer;
|
||||
}
|
||||
|
||||
bool
|
||||
Certificate::is_ca() const
|
||||
{
|
||||
return parsed_cert->is_ca;
|
||||
}
|
||||
|
||||
Certificate::ExpirationStatus
|
||||
Certificate::expiration_status() const
|
||||
{
|
||||
return parsed_cert->expiration_status();
|
||||
}
|
||||
|
||||
std::optional<bytes>
|
||||
Certificate::subject_key_id() const
|
||||
{
|
||||
return parsed_cert->subject_key_id;
|
||||
}
|
||||
|
||||
std::optional<bytes>
|
||||
Certificate::authority_key_id() const
|
||||
{
|
||||
return parsed_cert->authority_key_id;
|
||||
}
|
||||
|
||||
std::vector<std::string>
|
||||
Certificate::email_addresses() const
|
||||
{
|
||||
std::vector<std::string> emails;
|
||||
for (const auto& name : parsed_cert->sub_alt_names) {
|
||||
if (tls::var::holds_alternative<RFC822Name>(name)) {
|
||||
emails.emplace_back(tls::var::get<RFC822Name>(name).value);
|
||||
}
|
||||
}
|
||||
return emails;
|
||||
}
|
||||
|
||||
std::vector<std::string>
|
||||
Certificate::dns_names() const
|
||||
{
|
||||
std::vector<std::string> domains;
|
||||
for (const auto& name : parsed_cert->sub_alt_names) {
|
||||
if (tls::var::holds_alternative<DNSName>(name)) {
|
||||
domains.emplace_back(tls::var::get<DNSName>(name).value);
|
||||
}
|
||||
}
|
||||
|
||||
return domains;
|
||||
}
|
||||
|
||||
bytes
|
||||
Certificate::hash() const
|
||||
{
|
||||
return parsed_cert->hash;
|
||||
}
|
||||
|
||||
std::chrono::system_clock::time_point
|
||||
Certificate::not_before() const
|
||||
{
|
||||
return parsed_cert->not_before;
|
||||
}
|
||||
|
||||
std::chrono::system_clock::time_point
|
||||
Certificate::not_after() const
|
||||
{
|
||||
return parsed_cert->not_after;
|
||||
}
|
||||
|
||||
Signature::ID
|
||||
Certificate::public_key_algorithm() const
|
||||
{
|
||||
return parsed_cert->pub_key_id;
|
||||
}
|
||||
|
||||
Signature::ID
|
||||
Certificate::signature_algorithm() const
|
||||
{
|
||||
return parsed_cert->sig_algo;
|
||||
}
|
||||
|
||||
bool
|
||||
operator==(const Certificate& lhs, const Certificate& rhs)
|
||||
{
|
||||
return lhs.raw == rhs.raw;
|
||||
}
|
||||
|
||||
} // namespace mlspp::hpke
|
||||
20
DPP/mlspp/lib/hpke/src/common.cpp
Executable file
20
DPP/mlspp/lib/hpke/src/common.cpp
Executable file
@@ -0,0 +1,20 @@
|
||||
#include "common.h"
|
||||
|
||||
namespace mlspp::hpke {
|
||||
|
||||
bytes
|
||||
i2osp(uint64_t val, size_t size)
|
||||
{
|
||||
auto out = bytes(size, 0);
|
||||
auto max = size;
|
||||
if (size > 8) {
|
||||
max = 8;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < max; i++) {
|
||||
out.at(size - i - 1) = static_cast<uint8_t>(val >> (8 * i));
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace mlspp::hpke
|
||||
10
DPP/mlspp/lib/hpke/src/common.h
Executable file
10
DPP/mlspp/lib/hpke/src/common.h
Executable file
@@ -0,0 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
#include <hpke/hpke.h>
|
||||
|
||||
namespace mlspp::hpke {
|
||||
|
||||
bytes
|
||||
i2osp(uint64_t val, size_t size);
|
||||
|
||||
} // namespace mlspp::hpke
|
||||
216
DPP/mlspp/lib/hpke/src/dhkem.cpp
Executable file
216
DPP/mlspp/lib/hpke/src/dhkem.cpp
Executable file
@@ -0,0 +1,216 @@
|
||||
#include "dhkem.h"
|
||||
|
||||
#include "common.h"
|
||||
|
||||
namespace mlspp::hpke {
|
||||
|
||||
DHKEM::PrivateKey::PrivateKey(Group::PrivateKey* group_priv_in)
|
||||
: group_priv(group_priv_in)
|
||||
{
|
||||
}
|
||||
|
||||
std::unique_ptr<KEM::PublicKey>
|
||||
DHKEM::PrivateKey::public_key() const
|
||||
{
|
||||
return group_priv->public_key();
|
||||
}
|
||||
|
||||
DHKEM
|
||||
make_dhkem(KEM::ID kem_id_in, const Group& group_in, const KDF& kdf_in)
|
||||
{
|
||||
return { kem_id_in, group_in, kdf_in };
|
||||
}
|
||||
|
||||
template<>
|
||||
const DHKEM&
|
||||
DHKEM::get<KEM::ID::DHKEM_P256_SHA256>()
|
||||
{
|
||||
static const auto instance = make_dhkem(KEM::ID::DHKEM_P256_SHA256,
|
||||
Group::get<Group::ID::P256>(),
|
||||
KDF::get<KDF::ID::HKDF_SHA256>());
|
||||
return instance;
|
||||
}
|
||||
|
||||
template<>
|
||||
const DHKEM&
|
||||
DHKEM::get<KEM::ID::DHKEM_P384_SHA384>()
|
||||
{
|
||||
static const auto instance = make_dhkem(KEM::ID::DHKEM_P384_SHA384,
|
||||
Group::get<Group::ID::P384>(),
|
||||
KDF::get<KDF::ID::HKDF_SHA384>());
|
||||
return instance;
|
||||
}
|
||||
|
||||
template<>
|
||||
const DHKEM&
|
||||
DHKEM::get<KEM::ID::DHKEM_P521_SHA512>()
|
||||
{
|
||||
static const auto instance = make_dhkem(KEM::ID::DHKEM_P521_SHA512,
|
||||
Group::get<Group::ID::P521>(),
|
||||
KDF::get<KDF::ID::HKDF_SHA512>());
|
||||
return instance;
|
||||
}
|
||||
|
||||
template<>
|
||||
const DHKEM&
|
||||
DHKEM::get<KEM::ID::DHKEM_X25519_SHA256>()
|
||||
{
|
||||
static const auto instance = make_dhkem(KEM::ID::DHKEM_X25519_SHA256,
|
||||
Group::get<Group::ID::X25519>(),
|
||||
KDF::get<KDF::ID::HKDF_SHA256>());
|
||||
return instance;
|
||||
}
|
||||
|
||||
#if !defined(WITH_BORINGSSL)
|
||||
template<>
|
||||
const DHKEM&
|
||||
DHKEM::get<KEM::ID::DHKEM_X448_SHA512>()
|
||||
{
|
||||
static const auto instance = make_dhkem(KEM::ID::DHKEM_X448_SHA512,
|
||||
Group::get<Group::ID::X448>(),
|
||||
KDF::get<KDF::ID::HKDF_SHA512>());
|
||||
return instance;
|
||||
}
|
||||
#endif
|
||||
|
||||
DHKEM::DHKEM(KEM::ID kem_id_in, const Group& group_in, const KDF& kdf_in)
|
||||
: KEM(kem_id_in,
|
||||
kdf_in.hash_size,
|
||||
group_in.pk_size,
|
||||
group_in.pk_size,
|
||||
group_in.sk_size)
|
||||
, group(group_in)
|
||||
, kdf(kdf_in)
|
||||
{
|
||||
static const auto label_kem = from_ascii("KEM");
|
||||
suite_id = label_kem + i2osp(uint16_t(kem_id_in), 2);
|
||||
}
|
||||
|
||||
std::unique_ptr<KEM::PrivateKey>
|
||||
DHKEM::generate_key_pair() const
|
||||
{
|
||||
return std::make_unique<DHKEM::PrivateKey>(
|
||||
group.generate_key_pair().release());
|
||||
}
|
||||
|
||||
std::unique_ptr<KEM::PrivateKey>
|
||||
DHKEM::derive_key_pair(const bytes& ikm) const
|
||||
{
|
||||
return std::make_unique<DHKEM::PrivateKey>(
|
||||
group.derive_key_pair(suite_id, ikm).release());
|
||||
}
|
||||
|
||||
bytes
|
||||
DHKEM::serialize(const KEM::PublicKey& pk) const
|
||||
{
|
||||
const auto& gpk = dynamic_cast<const Group::PublicKey&>(pk);
|
||||
return group.serialize(gpk);
|
||||
}
|
||||
|
||||
std::unique_ptr<KEM::PublicKey>
|
||||
DHKEM::deserialize(const bytes& enc) const
|
||||
{
|
||||
return group.deserialize(enc);
|
||||
}
|
||||
|
||||
bytes
|
||||
DHKEM::serialize_private(const KEM::PrivateKey& sk) const
|
||||
{
|
||||
const auto& gsk = dynamic_cast<const PrivateKey&>(sk);
|
||||
return group.serialize_private(*gsk.group_priv);
|
||||
}
|
||||
|
||||
std::unique_ptr<KEM::PrivateKey>
|
||||
DHKEM::deserialize_private(const bytes& skm) const
|
||||
{
|
||||
return std::make_unique<PrivateKey>(group.deserialize_private(skm).release());
|
||||
}
|
||||
|
||||
std::pair<bytes, bytes>
|
||||
DHKEM::encap(const KEM::PublicKey& pkR) const
|
||||
{
|
||||
const auto& gpkR = dynamic_cast<const Group::PublicKey&>(pkR);
|
||||
|
||||
auto skE = group.generate_key_pair();
|
||||
auto pkE = skE->public_key();
|
||||
|
||||
auto zz = group.dh(*skE, gpkR);
|
||||
auto enc = group.serialize(*pkE);
|
||||
|
||||
auto pkRm = group.serialize(gpkR);
|
||||
auto kem_context = enc + pkRm;
|
||||
|
||||
auto shared_secret = extract_and_expand(zz, kem_context);
|
||||
return std::make_pair(shared_secret, enc);
|
||||
}
|
||||
|
||||
bytes
|
||||
DHKEM::decap(const bytes& enc, const KEM::PrivateKey& skR) const
|
||||
{
|
||||
const auto& gskR = dynamic_cast<const PrivateKey&>(skR);
|
||||
auto pkR = gskR.group_priv->public_key();
|
||||
auto pkE = group.deserialize(enc);
|
||||
auto zz = group.dh(*gskR.group_priv, *pkE);
|
||||
|
||||
auto pkRm = group.serialize(*pkR);
|
||||
auto kem_context = enc + pkRm;
|
||||
return extract_and_expand(zz, kem_context);
|
||||
}
|
||||
|
||||
std::pair<bytes, bytes>
|
||||
DHKEM::auth_encap(const KEM::PublicKey& pkR, const KEM::PrivateKey& skS) const
|
||||
{
|
||||
const auto& gpkR = dynamic_cast<const Group::PublicKey&>(pkR);
|
||||
const auto& gskS = dynamic_cast<const PrivateKey&>(skS);
|
||||
|
||||
auto skE = group.generate_key_pair();
|
||||
auto pkE = skE->public_key();
|
||||
auto pkS = gskS.group_priv->public_key();
|
||||
|
||||
auto zzER = group.dh(*skE, gpkR);
|
||||
auto zzSR = group.dh(*gskS.group_priv, gpkR);
|
||||
auto zz = zzER + zzSR;
|
||||
auto enc = group.serialize(*pkE);
|
||||
|
||||
auto pkRm = group.serialize(gpkR);
|
||||
auto pkSm = group.serialize(*pkS);
|
||||
auto kem_context = enc + pkRm + pkSm;
|
||||
|
||||
auto shared_secret = extract_and_expand(zz, kem_context);
|
||||
return std::make_pair(shared_secret, enc);
|
||||
}
|
||||
|
||||
bytes
|
||||
DHKEM::auth_decap(const bytes& enc,
|
||||
const KEM::PublicKey& pkS,
|
||||
const KEM::PrivateKey& skR) const
|
||||
{
|
||||
const auto& gpkS = dynamic_cast<const Group::PublicKey&>(pkS);
|
||||
const auto& gskR = dynamic_cast<const PrivateKey&>(skR);
|
||||
|
||||
auto pkE = group.deserialize(enc);
|
||||
auto pkR = gskR.group_priv->public_key();
|
||||
|
||||
auto zzER = group.dh(*gskR.group_priv, *pkE);
|
||||
auto zzSR = group.dh(*gskR.group_priv, gpkS);
|
||||
auto zz = zzER + zzSR;
|
||||
|
||||
auto pkRm = group.serialize(*pkR);
|
||||
auto pkSm = group.serialize(gpkS);
|
||||
auto kem_context = enc + pkRm + pkSm;
|
||||
|
||||
return extract_and_expand(zz, kem_context);
|
||||
}
|
||||
|
||||
bytes
|
||||
DHKEM::extract_and_expand(const bytes& dh, const bytes& kem_context) const
|
||||
{
|
||||
static const auto label_eae_prk = from_ascii("eae_prk");
|
||||
static const auto label_shared_secret = from_ascii("shared_secret");
|
||||
|
||||
auto eae_prk = kdf.labeled_extract(suite_id, {}, label_eae_prk, dh);
|
||||
return kdf.labeled_expand(
|
||||
suite_id, eae_prk, label_shared_secret, kem_context, secret_size);
|
||||
}
|
||||
|
||||
} // namespace mlspp::hpke
|
||||
57
DPP/mlspp/lib/hpke/src/dhkem.h
Executable file
57
DPP/mlspp/lib/hpke/src/dhkem.h
Executable file
@@ -0,0 +1,57 @@
|
||||
#pragma once
|
||||
|
||||
#include <hpke/hpke.h>
|
||||
|
||||
#include "group.h"
|
||||
|
||||
namespace mlspp::hpke {
|
||||
|
||||
struct DHKEM : public KEM
|
||||
{
|
||||
struct PrivateKey : public KEM::PrivateKey
|
||||
{
|
||||
PrivateKey(Group::PrivateKey* group_priv_in);
|
||||
std::unique_ptr<KEM::PublicKey> public_key() const override;
|
||||
|
||||
std::unique_ptr<Group::PrivateKey> group_priv;
|
||||
};
|
||||
|
||||
template<KEM::ID>
|
||||
static const DHKEM& get();
|
||||
|
||||
~DHKEM() override = default;
|
||||
|
||||
std::unique_ptr<KEM::PrivateKey> generate_key_pair() const override;
|
||||
std::unique_ptr<KEM::PrivateKey> derive_key_pair(
|
||||
const bytes& ikm) const override;
|
||||
|
||||
bytes serialize(const KEM::PublicKey& pk) const override;
|
||||
std::unique_ptr<KEM::PublicKey> deserialize(const bytes& enc) const override;
|
||||
|
||||
bytes serialize_private(const KEM::PrivateKey& sk) const override;
|
||||
std::unique_ptr<KEM::PrivateKey> deserialize_private(
|
||||
const bytes& skm) const override;
|
||||
|
||||
std::pair<bytes, bytes> encap(const KEM::PublicKey& pk) const override;
|
||||
bytes decap(const bytes& enc, const KEM::PrivateKey& sk) const override;
|
||||
|
||||
std::pair<bytes, bytes> auth_encap(const KEM::PublicKey& pkR,
|
||||
const KEM::PrivateKey& skS) const override;
|
||||
bytes auth_decap(const bytes& enc,
|
||||
const KEM::PublicKey& pkS,
|
||||
const KEM::PrivateKey& skR) const override;
|
||||
|
||||
private:
|
||||
const Group& group;
|
||||
const KDF& kdf;
|
||||
bytes suite_id;
|
||||
|
||||
bytes extract_and_expand(const bytes& dh, const bytes& kem_context) const;
|
||||
|
||||
DHKEM(KEM::ID kem_id_in, const Group& group_in, const KDF& kdf_in);
|
||||
friend DHKEM make_dhkem(KEM::ID kem_id_in,
|
||||
const Group& group_in,
|
||||
const KDF& kdf_in);
|
||||
};
|
||||
|
||||
} // namespace mlspp::hpke
|
||||
187
DPP/mlspp/lib/hpke/src/digest.cpp
Executable file
187
DPP/mlspp/lib/hpke/src/digest.cpp
Executable file
@@ -0,0 +1,187 @@
|
||||
#include <hpke/digest.h>
|
||||
|
||||
#include <openssl/evp.h>
|
||||
#include <openssl/hmac.h>
|
||||
#if defined(WITH_OPENSSL3)
|
||||
#include <openssl/core_names.h>
|
||||
#endif
|
||||
|
||||
#include "openssl_common.h"
|
||||
|
||||
namespace mlspp::hpke {
|
||||
|
||||
static const EVP_MD*
|
||||
openssl_digest_type(Digest::ID digest)
|
||||
{
|
||||
switch (digest) {
|
||||
case Digest::ID::SHA256:
|
||||
return EVP_sha256();
|
||||
|
||||
case Digest::ID::SHA384:
|
||||
return EVP_sha384();
|
||||
|
||||
case Digest::ID::SHA512:
|
||||
return EVP_sha512();
|
||||
|
||||
default:
|
||||
throw std::runtime_error("Unsupported ciphersuite");
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(WITH_OPENSSL3)
|
||||
static std::string
|
||||
openssl_digest_name(Digest::ID digest)
|
||||
{
|
||||
switch (digest) {
|
||||
case Digest::ID::SHA256:
|
||||
return OSSL_DIGEST_NAME_SHA2_256;
|
||||
|
||||
case Digest::ID::SHA384:
|
||||
return OSSL_DIGEST_NAME_SHA2_384;
|
||||
|
||||
case Digest::ID::SHA512:
|
||||
return OSSL_DIGEST_NAME_SHA2_512;
|
||||
|
||||
default:
|
||||
throw std::runtime_error("Unsupported digest algorithm");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
template<>
|
||||
const Digest&
|
||||
Digest::get<Digest::ID::SHA256>()
|
||||
{
|
||||
static const Digest instance(Digest::ID::SHA256);
|
||||
return instance;
|
||||
}
|
||||
|
||||
template<>
|
||||
const Digest&
|
||||
Digest::get<Digest::ID::SHA384>()
|
||||
{
|
||||
static const Digest instance(Digest::ID::SHA384);
|
||||
return instance;
|
||||
}
|
||||
|
||||
template<>
|
||||
const Digest&
|
||||
Digest::get<Digest::ID::SHA512>()
|
||||
{
|
||||
static const Digest instance(Digest::ID::SHA512);
|
||||
return instance;
|
||||
}
|
||||
|
||||
Digest::Digest(Digest::ID id_in)
|
||||
: id(id_in)
|
||||
, hash_size(EVP_MD_size(openssl_digest_type(id_in)))
|
||||
{
|
||||
}
|
||||
|
||||
bytes
|
||||
Digest::hash(const bytes& data) const
|
||||
{
|
||||
auto md = bytes(hash_size);
|
||||
unsigned int size = 0;
|
||||
const auto* type = openssl_digest_type(id);
|
||||
if (1 !=
|
||||
EVP_Digest(data.data(), data.size(), md.data(), &size, type, nullptr)) {
|
||||
throw openssl_error();
|
||||
}
|
||||
|
||||
return md;
|
||||
}
|
||||
|
||||
bytes
|
||||
Digest::hmac(const bytes& key, const bytes& data) const
|
||||
{
|
||||
auto md = bytes(hash_size);
|
||||
unsigned int size = 0;
|
||||
const auto* type = openssl_digest_type(id);
|
||||
if (nullptr == HMAC(type,
|
||||
key.data(),
|
||||
static_cast<int>(key.size()),
|
||||
data.data(),
|
||||
static_cast<int>(data.size()),
|
||||
md.data(),
|
||||
&size)) {
|
||||
throw openssl_error();
|
||||
}
|
||||
|
||||
return md;
|
||||
}
|
||||
|
||||
bytes
|
||||
Digest::hmac_for_hkdf_extract(const bytes& key, const bytes& data) const
|
||||
{
|
||||
#if defined(WITH_OPENSSL3)
|
||||
auto digest_name = openssl_digest_name(id);
|
||||
std::array<OSSL_PARAM, 2> params = {
|
||||
OSSL_PARAM_construct_utf8_string(
|
||||
OSSL_ALG_PARAM_DIGEST, digest_name.data(), 0),
|
||||
OSSL_PARAM_construct_end()
|
||||
};
|
||||
const auto mac =
|
||||
make_typed_unique(EVP_MAC_fetch(nullptr, OSSL_MAC_NAME_HMAC, nullptr));
|
||||
const auto ctx = make_typed_unique(EVP_MAC_CTX_new(mac.get()));
|
||||
#else
|
||||
const auto* type = openssl_digest_type(id);
|
||||
auto ctx = make_typed_unique(HMAC_CTX_new());
|
||||
#endif
|
||||
if (ctx == nullptr) {
|
||||
throw openssl_error();
|
||||
}
|
||||
|
||||
// Some FIPS-enabled libraries are overly conservative in their interpretation
|
||||
// of NIST SP 800-131A, which requires HMAC keys to be at least 112 bits long.
|
||||
// That document does not impose that requirement on HKDF, so we disable FIPS
|
||||
// enforcement for purposes of HKDF.
|
||||
//
|
||||
// https://doi.org/10.6028/NIST.SP.800-131Ar2
|
||||
auto key_size = static_cast<int>(key.size());
|
||||
// OpenSSL 3 does not support the flag EVP_MD_CTX_FLAG_NON_FIPS_ALLOW anymore.
|
||||
// However, OpenSSL 3 in FIPS mode doesn't seem to check the HMAC key size
|
||||
// constraint.
|
||||
#if !defined(WITH_OPENSSL3) && !defined(WITH_BORINGSSL)
|
||||
static const auto fips_min_hmac_key_len = 14;
|
||||
if (FIPS_mode() != 0 && key_size < fips_min_hmac_key_len) {
|
||||
HMAC_CTX_set_flags(ctx.get(), EVP_MD_CTX_FLAG_NON_FIPS_ALLOW);
|
||||
}
|
||||
#endif
|
||||
|
||||
// Guard against sending nullptr to HMAC_Init_ex
|
||||
const auto* key_data = key.data();
|
||||
const auto non_null_zero_length_key = uint8_t(0);
|
||||
if (key_data == nullptr) {
|
||||
key_data = &non_null_zero_length_key;
|
||||
}
|
||||
|
||||
auto md = bytes(hash_size);
|
||||
#if defined(WITH_OPENSSL3)
|
||||
if (1 != EVP_MAC_init(ctx.get(), key_data, key_size, params.data())) {
|
||||
throw openssl_error();
|
||||
}
|
||||
if (1 != EVP_MAC_update(ctx.get(), data.data(), data.size())) {
|
||||
throw openssl_error();
|
||||
}
|
||||
size_t size = 0;
|
||||
if (1 != EVP_MAC_final(ctx.get(), md.data(), &size, hash_size)) {
|
||||
throw openssl_error();
|
||||
}
|
||||
#else
|
||||
if (1 != HMAC_Init_ex(ctx.get(), key_data, key_size, type, nullptr)) {
|
||||
throw openssl_error();
|
||||
}
|
||||
if (1 != HMAC_Update(ctx.get(), data.data(), data.size())) {
|
||||
throw openssl_error();
|
||||
}
|
||||
unsigned int size = 0;
|
||||
if (1 != HMAC_Final(ctx.get(), md.data(), &size)) {
|
||||
throw openssl_error();
|
||||
}
|
||||
#endif
|
||||
|
||||
return md;
|
||||
}
|
||||
|
||||
} // namespace mlspp::hpke
|
||||
1077
DPP/mlspp/lib/hpke/src/group.cpp
Executable file
1077
DPP/mlspp/lib/hpke/src/group.cpp
Executable file
File diff suppressed because it is too large
Load Diff
116
DPP/mlspp/lib/hpke/src/group.h
Executable file
116
DPP/mlspp/lib/hpke/src/group.h
Executable file
@@ -0,0 +1,116 @@
|
||||
#pragma once
|
||||
|
||||
#include <hpke/hpke.h>
|
||||
#include <hpke/signature.h>
|
||||
|
||||
#include "openssl_common.h"
|
||||
#include <openssl/evp.h>
|
||||
|
||||
namespace mlspp::hpke {
|
||||
|
||||
struct Group
|
||||
{
|
||||
enum struct ID : uint8_t
|
||||
{
|
||||
P256,
|
||||
P384,
|
||||
P521,
|
||||
X25519,
|
||||
X448,
|
||||
Ed25519,
|
||||
Ed448,
|
||||
};
|
||||
|
||||
struct PublicKey
|
||||
: public KEM::PublicKey
|
||||
, public Signature::PublicKey
|
||||
{
|
||||
virtual ~PublicKey() = default;
|
||||
};
|
||||
|
||||
struct PrivateKey
|
||||
{
|
||||
virtual ~PrivateKey() = default;
|
||||
virtual std::unique_ptr<PublicKey> public_key() const = 0;
|
||||
};
|
||||
|
||||
template<Group::ID id>
|
||||
static const Group& get();
|
||||
|
||||
virtual ~Group() = default;
|
||||
|
||||
const ID id;
|
||||
const size_t dh_size;
|
||||
const size_t pk_size;
|
||||
const size_t sk_size;
|
||||
const std::string jwk_key_type;
|
||||
const std::string jwk_curve_name;
|
||||
|
||||
virtual std::unique_ptr<PrivateKey> generate_key_pair() const = 0;
|
||||
virtual std::unique_ptr<PrivateKey> derive_key_pair(
|
||||
const bytes& suite_id,
|
||||
const bytes& ikm) const = 0;
|
||||
|
||||
virtual bytes serialize(const PublicKey& pk) const = 0;
|
||||
virtual std::unique_ptr<PublicKey> deserialize(const bytes& enc) const = 0;
|
||||
|
||||
virtual bytes serialize_private(const PrivateKey& sk) const = 0;
|
||||
virtual std::unique_ptr<PrivateKey> deserialize_private(
|
||||
const bytes& skm) const = 0;
|
||||
|
||||
virtual bytes dh(const PrivateKey& sk, const PublicKey& pk) const = 0;
|
||||
|
||||
virtual bytes sign(const bytes& data, const PrivateKey& sk) const = 0;
|
||||
virtual bool verify(const bytes& data,
|
||||
const bytes& sig,
|
||||
const PublicKey& pk) const = 0;
|
||||
|
||||
virtual std::tuple<bytes, bytes> coordinates(const PublicKey& pk) const = 0;
|
||||
virtual std::unique_ptr<PublicKey> public_key_from_coordinates(
|
||||
const bytes& x,
|
||||
const bytes& y) const = 0;
|
||||
|
||||
protected:
|
||||
const KDF& kdf;
|
||||
|
||||
friend struct DHKEM;
|
||||
|
||||
Group(ID group_id_in, const KDF& kdf_in);
|
||||
};
|
||||
|
||||
struct EVPGroup : public Group
|
||||
{
|
||||
EVPGroup(Group::ID group_id, const KDF& kdf);
|
||||
|
||||
struct PublicKey : public Group::PublicKey
|
||||
{
|
||||
explicit PublicKey(EVP_PKEY* pkey_in);
|
||||
~PublicKey() override = default;
|
||||
|
||||
// NOLINTNEXTLINE(misc-non-private-member-variables-in-classes)
|
||||
typed_unique_ptr<EVP_PKEY> pkey;
|
||||
};
|
||||
|
||||
struct PrivateKey : public Group::PrivateKey
|
||||
{
|
||||
explicit PrivateKey(EVP_PKEY* pkey_in);
|
||||
~PrivateKey() override = default;
|
||||
|
||||
std::unique_ptr<Group::PublicKey> public_key() const override;
|
||||
|
||||
// NOLINTNEXTLINE(misc-non-private-member-variables-in-classes)
|
||||
typed_unique_ptr<EVP_PKEY> pkey;
|
||||
};
|
||||
|
||||
std::unique_ptr<Group::PrivateKey> generate_key_pair() const override;
|
||||
|
||||
bytes dh(const Group::PrivateKey& sk,
|
||||
const Group::PublicKey& pk) const override;
|
||||
|
||||
bytes sign(const bytes& data, const Group::PrivateKey& sk) const override;
|
||||
bool verify(const bytes& data,
|
||||
const bytes& sig,
|
||||
const Group::PublicKey& pk) const override;
|
||||
};
|
||||
|
||||
} // namespace mlspp::hpke
|
||||
79
DPP/mlspp/lib/hpke/src/hkdf.cpp
Executable file
79
DPP/mlspp/lib/hpke/src/hkdf.cpp
Executable file
@@ -0,0 +1,79 @@
|
||||
#include "hkdf.h"
|
||||
#include "openssl_common.h"
|
||||
|
||||
#include <openssl/err.h>
|
||||
#include <openssl/evp.h>
|
||||
#include <stdexcept>
|
||||
|
||||
namespace mlspp::hpke {
|
||||
|
||||
template<>
|
||||
const HKDF&
|
||||
HKDF::get<Digest::ID::SHA256>()
|
||||
{
|
||||
static const HKDF instance(Digest::get<Digest::ID::SHA256>());
|
||||
return instance;
|
||||
}
|
||||
|
||||
template<>
|
||||
const HKDF&
|
||||
HKDF::get<Digest::ID::SHA384>()
|
||||
{
|
||||
static const HKDF instance(Digest::get<Digest::ID::SHA384>());
|
||||
return instance;
|
||||
}
|
||||
|
||||
template<>
|
||||
const HKDF&
|
||||
HKDF::get<Digest::ID::SHA512>()
|
||||
{
|
||||
static const HKDF instance(Digest::get<Digest::ID::SHA512>());
|
||||
return instance;
|
||||
}
|
||||
|
||||
static KDF::ID
|
||||
digest_to_kdf(Digest::ID digest_id)
|
||||
{
|
||||
switch (digest_id) {
|
||||
case Digest::ID::SHA256:
|
||||
return KDF::ID::HKDF_SHA256;
|
||||
case Digest::ID::SHA384:
|
||||
return KDF::ID::HKDF_SHA384;
|
||||
case Digest::ID::SHA512:
|
||||
return KDF::ID::HKDF_SHA512;
|
||||
}
|
||||
|
||||
throw std::runtime_error("Unsupported algorithm");
|
||||
}
|
||||
|
||||
HKDF::HKDF(const Digest& digest_in)
|
||||
: KDF(digest_to_kdf(digest_in.id), digest_in.hash_size)
|
||||
, digest(digest_in)
|
||||
{
|
||||
}
|
||||
|
||||
bytes
|
||||
HKDF::extract(const bytes& salt, const bytes& ikm) const
|
||||
{
|
||||
return digest.hmac_for_hkdf_extract(salt, ikm);
|
||||
}
|
||||
|
||||
bytes
|
||||
HKDF::expand(const bytes& prk, const bytes& info, size_t size) const
|
||||
{
|
||||
auto okm = bytes{};
|
||||
auto i = uint8_t(0x00);
|
||||
auto Ti = bytes{};
|
||||
while (okm.size() < size) {
|
||||
i += 1;
|
||||
auto block = Ti + info + bytes{ i };
|
||||
|
||||
Ti = digest.hmac(prk, block);
|
||||
okm += Ti;
|
||||
}
|
||||
|
||||
okm.resize(size);
|
||||
return okm;
|
||||
}
|
||||
|
||||
} // namespace mlspp::hpke
|
||||
24
DPP/mlspp/lib/hpke/src/hkdf.h
Executable file
24
DPP/mlspp/lib/hpke/src/hkdf.h
Executable file
@@ -0,0 +1,24 @@
|
||||
#pragma once
|
||||
|
||||
#include <hpke/digest.h>
|
||||
#include <hpke/hpke.h>
|
||||
|
||||
namespace mlspp::hpke {
|
||||
|
||||
struct HKDF : public KDF
|
||||
{
|
||||
template<Digest::ID digest_id>
|
||||
static const HKDF& get();
|
||||
|
||||
~HKDF() override = default;
|
||||
|
||||
bytes extract(const bytes& salt, const bytes& ikm) const override;
|
||||
bytes expand(const bytes& prk, const bytes& info, size_t size) const override;
|
||||
|
||||
private:
|
||||
const Digest& digest;
|
||||
|
||||
explicit HKDF(const Digest& digest_in);
|
||||
};
|
||||
|
||||
} // namespace mlspp::hpke
|
||||
540
DPP/mlspp/lib/hpke/src/hpke.cpp
Executable file
540
DPP/mlspp/lib/hpke/src/hpke.cpp
Executable file
@@ -0,0 +1,540 @@
|
||||
#include <hpke/digest.h>
|
||||
#include <hpke/hpke.h>
|
||||
|
||||
#include "aead_cipher.h"
|
||||
#include "common.h"
|
||||
#include "dhkem.h"
|
||||
#include "hkdf.h"
|
||||
|
||||
#include <limits>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
namespace mlspp::hpke {
|
||||
|
||||
///
|
||||
/// Helper functions and constants
|
||||
///
|
||||
|
||||
static const bytes&
|
||||
label_exp()
|
||||
{
|
||||
static const bytes val = from_ascii("exp");
|
||||
return val;
|
||||
}
|
||||
|
||||
static const bytes&
|
||||
label_hpke()
|
||||
{
|
||||
static const bytes val = from_ascii("HPKE");
|
||||
return val;
|
||||
}
|
||||
|
||||
static const bytes&
|
||||
label_hpke_version()
|
||||
{
|
||||
static const bytes val = from_ascii("HPKE-v1");
|
||||
return val;
|
||||
}
|
||||
|
||||
static const bytes&
|
||||
label_info_hash()
|
||||
{
|
||||
static const bytes val = from_ascii("info_hash");
|
||||
return val;
|
||||
}
|
||||
|
||||
static const bytes&
|
||||
label_key()
|
||||
{
|
||||
static const bytes val = from_ascii("key");
|
||||
return val;
|
||||
}
|
||||
|
||||
static const bytes&
|
||||
label_base_nonce()
|
||||
{
|
||||
static const bytes val = from_ascii("base_nonce");
|
||||
return val;
|
||||
}
|
||||
|
||||
static const bytes&
|
||||
label_psk_id_hash()
|
||||
{
|
||||
static const bytes val = from_ascii("psk_id_hash");
|
||||
return val;
|
||||
}
|
||||
|
||||
static const bytes&
|
||||
label_sec()
|
||||
{
|
||||
static const bytes val = from_ascii("sec");
|
||||
return val;
|
||||
}
|
||||
|
||||
static const bytes&
|
||||
label_secret()
|
||||
{
|
||||
static const bytes val = from_ascii("secret");
|
||||
return val;
|
||||
}
|
||||
|
||||
///
|
||||
/// Factory methods for primitives
|
||||
///
|
||||
|
||||
KEM::KEM(ID id_in,
|
||||
size_t secret_size_in,
|
||||
size_t enc_size_in,
|
||||
size_t pk_size_in,
|
||||
size_t sk_size_in)
|
||||
: id(id_in)
|
||||
, secret_size(secret_size_in)
|
||||
, enc_size(enc_size_in)
|
||||
, pk_size(pk_size_in)
|
||||
, sk_size(sk_size_in)
|
||||
{
|
||||
}
|
||||
|
||||
template<>
|
||||
const KEM&
|
||||
KEM::get<KEM::ID::DHKEM_P256_SHA256>()
|
||||
{
|
||||
return DHKEM::get<KEM::ID::DHKEM_P256_SHA256>();
|
||||
}
|
||||
|
||||
template<>
|
||||
const KEM&
|
||||
KEM::get<KEM::ID::DHKEM_P384_SHA384>()
|
||||
{
|
||||
return DHKEM::get<KEM::ID::DHKEM_P384_SHA384>();
|
||||
}
|
||||
|
||||
template<>
|
||||
const KEM&
|
||||
KEM::get<KEM::ID::DHKEM_P521_SHA512>()
|
||||
{
|
||||
return DHKEM::get<KEM::ID::DHKEM_P521_SHA512>();
|
||||
}
|
||||
|
||||
template<>
|
||||
const KEM&
|
||||
KEM::get<KEM::ID::DHKEM_X25519_SHA256>()
|
||||
{
|
||||
return DHKEM::get<KEM::ID::DHKEM_X25519_SHA256>();
|
||||
}
|
||||
|
||||
#if !defined(WITH_BORINGSSL)
|
||||
template<>
|
||||
const KEM&
|
||||
KEM::get<KEM::ID::DHKEM_X448_SHA512>()
|
||||
{
|
||||
return DHKEM::get<KEM::ID::DHKEM_X448_SHA512>();
|
||||
}
|
||||
#endif
|
||||
|
||||
bytes
|
||||
KEM::serialize_private(const KEM::PrivateKey& /* unused */) const
|
||||
{
|
||||
throw std::runtime_error("Not implemented");
|
||||
}
|
||||
|
||||
std::unique_ptr<KEM::PrivateKey>
|
||||
KEM::deserialize_private(const bytes& /* unused */) const
|
||||
{
|
||||
throw std::runtime_error("Not implemented");
|
||||
}
|
||||
|
||||
std::pair<bytes, bytes>
|
||||
KEM::auth_encap(const PublicKey& /* unused */,
|
||||
const PrivateKey& /* unused */) const
|
||||
{
|
||||
throw std::runtime_error("Not implemented");
|
||||
}
|
||||
|
||||
bytes
|
||||
KEM::auth_decap(const bytes& /* unused */,
|
||||
const PublicKey& /* unused */,
|
||||
const PrivateKey& /* unused */) const
|
||||
{
|
||||
throw std::runtime_error("Not implemented");
|
||||
}
|
||||
|
||||
template<>
|
||||
const KDF&
|
||||
KDF::get<KDF::ID::HKDF_SHA256>()
|
||||
{
|
||||
return HKDF::get<Digest::ID::SHA256>();
|
||||
}
|
||||
|
||||
template<>
|
||||
const KDF&
|
||||
KDF::get<KDF::ID::HKDF_SHA384>()
|
||||
{
|
||||
return HKDF::get<Digest::ID::SHA384>();
|
||||
}
|
||||
|
||||
template<>
|
||||
const KDF&
|
||||
KDF::get<KDF::ID::HKDF_SHA512>()
|
||||
{
|
||||
return HKDF::get<Digest::ID::SHA512>();
|
||||
}
|
||||
|
||||
KDF::KDF(ID id_in, size_t hash_size_in)
|
||||
: id(id_in)
|
||||
, hash_size(hash_size_in)
|
||||
{
|
||||
}
|
||||
|
||||
bytes
|
||||
KDF::labeled_extract(const bytes& suite_id,
|
||||
const bytes& salt,
|
||||
const bytes& label,
|
||||
const bytes& ikm) const
|
||||
{
|
||||
auto labeled_ikm = label_hpke_version() + suite_id + label + ikm;
|
||||
return extract(salt, labeled_ikm);
|
||||
}
|
||||
|
||||
bytes
|
||||
KDF::labeled_expand(const bytes& suite_id,
|
||||
const bytes& prk,
|
||||
const bytes& label,
|
||||
const bytes& info,
|
||||
size_t size) const
|
||||
{
|
||||
auto labeled_info =
|
||||
i2osp(size, 2) + label_hpke_version() + suite_id + label + info;
|
||||
return expand(prk, labeled_info, size);
|
||||
}
|
||||
|
||||
template<>
|
||||
const AEAD&
|
||||
AEAD::get<AEAD::ID::AES_128_GCM>()
|
||||
{
|
||||
return AEADCipher::get<AEAD::ID::AES_128_GCM>();
|
||||
}
|
||||
|
||||
template<>
|
||||
const AEAD&
|
||||
AEAD::get<AEAD::ID::AES_256_GCM>()
|
||||
{
|
||||
return AEADCipher::get<AEAD::ID::AES_256_GCM>();
|
||||
}
|
||||
|
||||
template<>
|
||||
const AEAD&
|
||||
AEAD::get<AEAD::ID::CHACHA20_POLY1305>()
|
||||
{
|
||||
return AEADCipher::get<AEAD::ID::CHACHA20_POLY1305>();
|
||||
}
|
||||
|
||||
template<>
|
||||
const AEAD&
|
||||
AEAD::get<AEAD::ID::export_only>()
|
||||
{
|
||||
static const auto export_only = ExportOnlyCipher{};
|
||||
return export_only;
|
||||
}
|
||||
|
||||
AEAD::AEAD(ID id_in, size_t key_size_in, size_t nonce_size_in)
|
||||
: id(id_in)
|
||||
, key_size(key_size_in)
|
||||
, nonce_size(nonce_size_in)
|
||||
{
|
||||
}
|
||||
|
||||
///
|
||||
/// Encryption Contexts
|
||||
///
|
||||
|
||||
bytes
|
||||
Context::do_export(const bytes& exporter_context, size_t size) const
|
||||
{
|
||||
return kdf.labeled_expand(
|
||||
suite, exporter_secret, label_sec(), exporter_context, size);
|
||||
}
|
||||
|
||||
bytes
|
||||
Context::current_nonce() const
|
||||
{
|
||||
auto curr = i2osp(seq, aead.nonce_size);
|
||||
return curr ^ nonce;
|
||||
}
|
||||
|
||||
void
|
||||
Context::increment_seq()
|
||||
{
|
||||
if (seq == std::numeric_limits<uint64_t>::max()) {
|
||||
throw std::runtime_error("Sequence number overflow");
|
||||
}
|
||||
|
||||
seq += 1;
|
||||
}
|
||||
|
||||
Context::Context(bytes suite_in,
|
||||
bytes key_in,
|
||||
bytes nonce_in,
|
||||
bytes exporter_secret_in,
|
||||
const KDF& kdf_in,
|
||||
const AEAD& aead_in)
|
||||
: suite(std::move(suite_in))
|
||||
, key(std::move(key_in))
|
||||
, nonce(std::move(nonce_in))
|
||||
, exporter_secret(std::move(exporter_secret_in))
|
||||
, kdf(kdf_in)
|
||||
, aead(aead_in)
|
||||
, seq(0)
|
||||
{
|
||||
}
|
||||
|
||||
bool
|
||||
operator==(const Context& lhs, const Context& rhs)
|
||||
{
|
||||
// TODO(RLB) Compare KDF and AEAD algorithms
|
||||
auto suite = (lhs.suite == rhs.suite);
|
||||
auto key = (lhs.key == rhs.key);
|
||||
auto nonce = (lhs.nonce == rhs.nonce);
|
||||
auto exporter_secret = (lhs.exporter_secret == rhs.exporter_secret);
|
||||
auto seq = (lhs.seq == rhs.seq);
|
||||
return suite && key && nonce && exporter_secret && seq;
|
||||
}
|
||||
|
||||
SenderContext::SenderContext(Context&& c)
|
||||
: Context(std::move(c))
|
||||
{
|
||||
}
|
||||
|
||||
bytes
|
||||
SenderContext::seal(const bytes& aad, const bytes& pt)
|
||||
{
|
||||
auto ct = aead.seal(key, current_nonce(), aad, pt);
|
||||
increment_seq();
|
||||
return ct;
|
||||
}
|
||||
|
||||
ReceiverContext::ReceiverContext(Context&& c)
|
||||
: Context(std::move(c))
|
||||
{
|
||||
}
|
||||
|
||||
std::optional<bytes>
|
||||
ReceiverContext::open(const bytes& aad, const bytes& ct)
|
||||
{
|
||||
auto maybe_pt = aead.open(key, current_nonce(), aad, ct);
|
||||
increment_seq();
|
||||
return maybe_pt;
|
||||
}
|
||||
|
||||
///
|
||||
/// HPKE
|
||||
///
|
||||
|
||||
static const bytes default_psk = {};
|
||||
static const bytes default_psk_id = {};
|
||||
|
||||
static bytes
|
||||
suite_id(KEM::ID kem_id, KDF::ID kdf_id, AEAD::ID aead_id)
|
||||
{
|
||||
return label_hpke() + i2osp(static_cast<uint64_t>(kem_id), 2) +
|
||||
i2osp(static_cast<uint64_t>(kdf_id), 2) +
|
||||
i2osp(static_cast<uint64_t>(aead_id), 2);
|
||||
}
|
||||
|
||||
static const KEM&
|
||||
select_kem(KEM::ID id)
|
||||
{
|
||||
switch (id) {
|
||||
case KEM::ID::DHKEM_P256_SHA256:
|
||||
return KEM::get<KEM::ID::DHKEM_P256_SHA256>();
|
||||
case KEM::ID::DHKEM_P384_SHA384:
|
||||
return KEM::get<KEM::ID::DHKEM_P384_SHA384>();
|
||||
case KEM::ID::DHKEM_P521_SHA512:
|
||||
return KEM::get<KEM::ID::DHKEM_P521_SHA512>();
|
||||
case KEM::ID::DHKEM_X25519_SHA256:
|
||||
return KEM::get<KEM::ID::DHKEM_X25519_SHA256>();
|
||||
#if !defined(WITH_BORINGSSL)
|
||||
case KEM::ID::DHKEM_X448_SHA512:
|
||||
return KEM::get<KEM::ID::DHKEM_X448_SHA512>();
|
||||
#endif
|
||||
default:
|
||||
throw std::runtime_error("Unsupported algorithm");
|
||||
}
|
||||
}
|
||||
|
||||
static const KDF&
|
||||
select_kdf(KDF::ID id)
|
||||
{
|
||||
switch (id) {
|
||||
case KDF::ID::HKDF_SHA256:
|
||||
return KDF::get<KDF::ID::HKDF_SHA256>();
|
||||
case KDF::ID::HKDF_SHA384:
|
||||
return KDF::get<KDF::ID::HKDF_SHA384>();
|
||||
case KDF::ID::HKDF_SHA512:
|
||||
return KDF::get<KDF::ID::HKDF_SHA512>();
|
||||
default:
|
||||
throw std::runtime_error("Unsupported algorithm");
|
||||
}
|
||||
}
|
||||
|
||||
static const AEAD&
|
||||
select_aead(AEAD::ID id)
|
||||
{
|
||||
switch (id) {
|
||||
case AEAD::ID::AES_128_GCM:
|
||||
return AEAD::get<AEAD::ID::AES_128_GCM>();
|
||||
case AEAD::ID::AES_256_GCM:
|
||||
return AEAD::get<AEAD::ID::AES_256_GCM>();
|
||||
case AEAD::ID::CHACHA20_POLY1305:
|
||||
return AEAD::get<AEAD::ID::CHACHA20_POLY1305>();
|
||||
case AEAD::ID::export_only:
|
||||
return AEAD::get<AEAD::ID::export_only>();
|
||||
default:
|
||||
throw std::runtime_error("Unsupported algorithm");
|
||||
}
|
||||
}
|
||||
|
||||
HPKE::HPKE(KEM::ID kem_id, KDF::ID kdf_id, AEAD::ID aead_id)
|
||||
: suite(suite_id(kem_id, kdf_id, aead_id))
|
||||
, kem(select_kem(kem_id))
|
||||
, kdf(select_kdf(kdf_id))
|
||||
, aead(select_aead(aead_id))
|
||||
{
|
||||
}
|
||||
|
||||
HPKE::SenderInfo
|
||||
HPKE::setup_base_s(const KEM::PublicKey& pkR, const bytes& info) const
|
||||
{
|
||||
auto [shared_secret, enc] = kem.encap(pkR);
|
||||
auto ctx =
|
||||
key_schedule(Mode::base, shared_secret, info, default_psk, default_psk_id);
|
||||
return std::make_pair(enc, SenderContext(std::move(ctx)));
|
||||
}
|
||||
|
||||
ReceiverContext
|
||||
HPKE::setup_base_r(const bytes& enc,
|
||||
const KEM::PrivateKey& skR,
|
||||
const bytes& info) const
|
||||
{
|
||||
auto pkRm = kem.serialize(*skR.public_key());
|
||||
auto shared_secret = kem.decap(enc, skR);
|
||||
auto ctx =
|
||||
key_schedule(Mode::base, shared_secret, info, default_psk, default_psk_id);
|
||||
return { std::move(ctx) };
|
||||
}
|
||||
|
||||
HPKE::SenderInfo
|
||||
HPKE::setup_psk_s(const KEM::PublicKey& pkR,
|
||||
const bytes& info,
|
||||
const bytes& psk,
|
||||
const bytes& psk_id) const
|
||||
{
|
||||
auto [shared_secret, enc] = kem.encap(pkR);
|
||||
auto ctx = key_schedule(Mode::psk, shared_secret, info, psk, psk_id);
|
||||
return std::make_pair(enc, SenderContext(std::move(ctx)));
|
||||
}
|
||||
|
||||
ReceiverContext
|
||||
HPKE::setup_psk_r(const bytes& enc,
|
||||
const KEM::PrivateKey& skR,
|
||||
const bytes& info,
|
||||
const bytes& psk,
|
||||
const bytes& psk_id) const
|
||||
{
|
||||
auto shared_secret = kem.decap(enc, skR);
|
||||
auto ctx = key_schedule(Mode::psk, shared_secret, info, psk, psk_id);
|
||||
return { std::move(ctx) };
|
||||
}
|
||||
|
||||
HPKE::SenderInfo
|
||||
HPKE::setup_auth_s(const KEM::PublicKey& pkR,
|
||||
const bytes& info,
|
||||
const KEM::PrivateKey& skS) const
|
||||
{
|
||||
auto [shared_secret, enc] = kem.auth_encap(pkR, skS);
|
||||
auto ctx =
|
||||
key_schedule(Mode::auth, shared_secret, info, default_psk, default_psk_id);
|
||||
return std::make_pair(enc, SenderContext(std::move(ctx)));
|
||||
}
|
||||
|
||||
ReceiverContext
|
||||
HPKE::setup_auth_r(const bytes& enc,
|
||||
const KEM::PrivateKey& skR,
|
||||
const bytes& info,
|
||||
const KEM::PublicKey& pkS) const
|
||||
{
|
||||
auto shared_secret = kem.auth_decap(enc, pkS, skR);
|
||||
auto ctx =
|
||||
key_schedule(Mode::auth, shared_secret, info, default_psk, default_psk_id);
|
||||
return { std::move(ctx) };
|
||||
}
|
||||
|
||||
HPKE::SenderInfo
|
||||
HPKE::setup_auth_psk_s(const KEM::PublicKey& pkR,
|
||||
const bytes& info,
|
||||
const bytes& psk,
|
||||
const bytes& psk_id,
|
||||
const KEM::PrivateKey& skS) const
|
||||
{
|
||||
auto [shared_secret, enc] = kem.auth_encap(pkR, skS);
|
||||
auto ctx = key_schedule(Mode::auth_psk, shared_secret, info, psk, psk_id);
|
||||
return std::make_pair(enc, SenderContext(std::move(ctx)));
|
||||
}
|
||||
|
||||
ReceiverContext
|
||||
HPKE::setup_auth_psk_r(const bytes& enc,
|
||||
const KEM::PrivateKey& skR,
|
||||
const bytes& info,
|
||||
const bytes& psk,
|
||||
const bytes& psk_id,
|
||||
const KEM::PublicKey& pkS) const
|
||||
{
|
||||
auto shared_secret = kem.auth_decap(enc, pkS, skR);
|
||||
auto ctx = key_schedule(Mode::auth_psk, shared_secret, info, psk, psk_id);
|
||||
return { std::move(ctx) };
|
||||
}
|
||||
|
||||
bool
|
||||
HPKE::verify_psk_inputs(Mode mode, const bytes& psk, const bytes& psk_id)
|
||||
{
|
||||
auto got_psk = (psk != default_psk);
|
||||
auto got_psk_id = (psk_id != default_psk_id);
|
||||
if (got_psk != got_psk_id) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return (!got_psk && (mode == Mode::base || mode == Mode::auth)) ||
|
||||
(got_psk && (mode == Mode::psk || mode == Mode::auth_psk));
|
||||
}
|
||||
|
||||
Context
|
||||
HPKE::key_schedule(Mode mode,
|
||||
const bytes& shared_secret,
|
||||
const bytes& info,
|
||||
const bytes& psk,
|
||||
const bytes& psk_id) const
|
||||
{
|
||||
if (!verify_psk_inputs(mode, psk, psk_id)) {
|
||||
throw std::runtime_error("Invalid PSK inputs");
|
||||
}
|
||||
|
||||
auto psk_id_hash =
|
||||
kdf.labeled_extract(suite, {}, label_psk_id_hash(), psk_id);
|
||||
auto info_hash = kdf.labeled_extract(suite, {}, label_info_hash(), info);
|
||||
auto mode_bytes = bytes{ uint8_t(mode) };
|
||||
auto key_schedule_context = mode_bytes + psk_id_hash + info_hash;
|
||||
|
||||
auto secret = kdf.labeled_extract(suite, shared_secret, label_secret(), psk);
|
||||
|
||||
auto key = kdf.labeled_expand(
|
||||
suite, secret, label_key(), key_schedule_context, aead.key_size);
|
||||
auto nonce = kdf.labeled_expand(
|
||||
suite, secret, label_base_nonce(), key_schedule_context, aead.nonce_size);
|
||||
auto exporter_secret = kdf.labeled_expand(
|
||||
suite, secret, label_exp(), key_schedule_context, kdf.hash_size);
|
||||
|
||||
return { suite, key, nonce, exporter_secret, kdf, aead };
|
||||
}
|
||||
|
||||
} // namespace mlspp::hpke
|
||||
160
DPP/mlspp/lib/hpke/src/openssl_common.cpp
Executable file
160
DPP/mlspp/lib/hpke/src/openssl_common.cpp
Executable file
@@ -0,0 +1,160 @@
|
||||
#include "openssl_common.h"
|
||||
|
||||
#include <openssl/ec.h>
|
||||
#include <openssl/err.h>
|
||||
#include <openssl/evp.h>
|
||||
#include <openssl/hmac.h>
|
||||
#include <openssl/x509.h>
|
||||
#include <openssl/x509v3.h>
|
||||
#if defined(WITH_OPENSSL3)
|
||||
#include <openssl/param_build.h>
|
||||
#endif
|
||||
|
||||
namespace mlspp::hpke {
|
||||
|
||||
template<>
|
||||
void
|
||||
typed_delete(EVP_CIPHER_CTX* ptr)
|
||||
{
|
||||
EVP_CIPHER_CTX_free(ptr);
|
||||
}
|
||||
|
||||
#if WITH_BORINGSSL
|
||||
template<>
|
||||
void
|
||||
typed_delete(EVP_AEAD_CTX* ptr)
|
||||
{
|
||||
EVP_AEAD_CTX_free(ptr);
|
||||
}
|
||||
#endif
|
||||
|
||||
template<>
|
||||
void
|
||||
typed_delete(EVP_PKEY_CTX* ptr)
|
||||
{
|
||||
EVP_PKEY_CTX_free(ptr);
|
||||
}
|
||||
|
||||
template<>
|
||||
void
|
||||
typed_delete(EVP_MD_CTX* ptr)
|
||||
{
|
||||
EVP_MD_CTX_free(ptr);
|
||||
}
|
||||
|
||||
#if !defined(WITH_OPENSSL3)
|
||||
template<>
|
||||
void
|
||||
typed_delete(HMAC_CTX* ptr)
|
||||
{
|
||||
HMAC_CTX_free(ptr);
|
||||
}
|
||||
#endif
|
||||
|
||||
template<>
|
||||
void
|
||||
typed_delete(EVP_PKEY* ptr)
|
||||
{
|
||||
EVP_PKEY_free(ptr);
|
||||
}
|
||||
|
||||
template<>
|
||||
void
|
||||
typed_delete(BIGNUM* ptr)
|
||||
{
|
||||
BN_free(ptr);
|
||||
}
|
||||
|
||||
template<>
|
||||
void
|
||||
typed_delete(EC_POINT* ptr)
|
||||
{
|
||||
EC_POINT_free(ptr);
|
||||
}
|
||||
|
||||
#if !defined(WITH_OPENSSL3)
|
||||
template<>
|
||||
void
|
||||
typed_delete(EC_KEY* ptr)
|
||||
{
|
||||
EC_KEY_free(ptr);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(WITH_OPENSSL3)
|
||||
template<>
|
||||
void
|
||||
typed_delete(EVP_MAC* ptr)
|
||||
{
|
||||
EVP_MAC_free(ptr);
|
||||
}
|
||||
|
||||
template<>
|
||||
void
|
||||
typed_delete(EVP_MAC_CTX* ptr)
|
||||
{
|
||||
EVP_MAC_CTX_free(ptr);
|
||||
}
|
||||
|
||||
template<>
|
||||
void
|
||||
typed_delete(EC_GROUP* ptr)
|
||||
{
|
||||
EC_GROUP_free(ptr);
|
||||
}
|
||||
|
||||
template<>
|
||||
void
|
||||
typed_delete(OSSL_PARAM_BLD* ptr)
|
||||
{
|
||||
OSSL_PARAM_BLD_free(ptr);
|
||||
}
|
||||
|
||||
template<>
|
||||
void
|
||||
typed_delete(OSSL_PARAM* ptr)
|
||||
{
|
||||
OSSL_PARAM_free(ptr);
|
||||
}
|
||||
#endif
|
||||
|
||||
template<>
|
||||
void
|
||||
typed_delete(X509* ptr)
|
||||
{
|
||||
X509_free(ptr);
|
||||
}
|
||||
|
||||
template<>
|
||||
void
|
||||
typed_delete(STACK_OF(GENERAL_NAME) * ptr)
|
||||
{
|
||||
sk_GENERAL_NAME_pop_free(ptr, GENERAL_NAME_free);
|
||||
}
|
||||
|
||||
template<>
|
||||
void
|
||||
typed_delete(BIO* ptr)
|
||||
{
|
||||
BIO_vfree(ptr);
|
||||
}
|
||||
|
||||
template<>
|
||||
void
|
||||
typed_delete(ASN1_TIME* ptr)
|
||||
{
|
||||
ASN1_TIME_free(ptr);
|
||||
}
|
||||
|
||||
///
|
||||
/// Map OpenSSL errors to C++ exceptions
|
||||
///
|
||||
|
||||
std::runtime_error
|
||||
openssl_error()
|
||||
{
|
||||
auto code = ERR_get_error();
|
||||
return std::runtime_error(ERR_error_string(code, nullptr));
|
||||
}
|
||||
|
||||
} // namespace mlspp::hpke
|
||||
25
DPP/mlspp/lib/hpke/src/openssl_common.h
Executable file
25
DPP/mlspp/lib/hpke/src/openssl_common.h
Executable file
@@ -0,0 +1,25 @@
|
||||
#pragma once
|
||||
|
||||
#include <hpke/hpke.h>
|
||||
#include <stdexcept>
|
||||
|
||||
namespace mlspp::hpke {
|
||||
|
||||
template<typename T>
|
||||
void
|
||||
typed_delete(T* ptr);
|
||||
|
||||
template<typename T>
|
||||
using typed_unique_ptr = std::unique_ptr<T, decltype(&typed_delete<T>)>;
|
||||
|
||||
template<typename T>
|
||||
typed_unique_ptr<T>
|
||||
make_typed_unique(T* ptr)
|
||||
{
|
||||
return typed_unique_ptr<T>(ptr, typed_delete<T>);
|
||||
}
|
||||
|
||||
std::runtime_error
|
||||
openssl_error();
|
||||
|
||||
} // namespace mlspp::hpke
|
||||
19
DPP/mlspp/lib/hpke/src/random.cpp
Executable file
19
DPP/mlspp/lib/hpke/src/random.cpp
Executable file
@@ -0,0 +1,19 @@
|
||||
#include <hpke/random.h>
|
||||
|
||||
#include "openssl_common.h"
|
||||
|
||||
#include <openssl/rand.h>
|
||||
|
||||
namespace mlspp::hpke {
|
||||
|
||||
bytes
|
||||
random_bytes(size_t size)
|
||||
{
|
||||
auto rand = bytes(size);
|
||||
if (1 != RAND_bytes(rand.data(), static_cast<int>(size))) {
|
||||
throw openssl_error();
|
||||
}
|
||||
return rand;
|
||||
}
|
||||
|
||||
} // namespace mlspp::hpke
|
||||
207
DPP/mlspp/lib/hpke/src/rsa.cpp
Executable file
207
DPP/mlspp/lib/hpke/src/rsa.cpp
Executable file
@@ -0,0 +1,207 @@
|
||||
#include "rsa.h"
|
||||
|
||||
#include "common.h"
|
||||
#include "openssl/rsa.h"
|
||||
#include "openssl_common.h"
|
||||
|
||||
namespace mlspp::hpke {
|
||||
|
||||
std::unique_ptr<Signature::PrivateKey>
|
||||
RSASignature::generate_key_pair() const
|
||||
{
|
||||
throw std::runtime_error("Not implemented");
|
||||
}
|
||||
|
||||
std::unique_ptr<Signature::PrivateKey>
|
||||
RSASignature::derive_key_pair(const bytes& /*ikm*/) const
|
||||
{
|
||||
throw std::runtime_error("Not implemented");
|
||||
}
|
||||
|
||||
std::unique_ptr<Signature::PrivateKey>
|
||||
RSASignature::generate_key_pair(size_t bits)
|
||||
{
|
||||
auto ctx = make_typed_unique(EVP_PKEY_CTX_new_id(EVP_PKEY_RSA, nullptr));
|
||||
if (ctx == nullptr) {
|
||||
throw openssl_error();
|
||||
}
|
||||
|
||||
if (EVP_PKEY_keygen_init(ctx.get()) <= 0) {
|
||||
throw openssl_error();
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(hicpp-signed-bitwise)
|
||||
if (EVP_PKEY_CTX_set_rsa_keygen_bits(ctx.get(), static_cast<int>(bits)) <=
|
||||
0) {
|
||||
throw openssl_error();
|
||||
}
|
||||
|
||||
auto* pkey = static_cast<EVP_PKEY*>(nullptr);
|
||||
if (EVP_PKEY_keygen(ctx.get(), &pkey) <= 0) {
|
||||
throw openssl_error();
|
||||
}
|
||||
|
||||
return std::make_unique<PrivateKey>(pkey);
|
||||
}
|
||||
|
||||
// TODO(rlb): Implement derive() with sizes
|
||||
|
||||
bytes
|
||||
RSASignature::serialize(const Signature::PublicKey& pk) const
|
||||
{
|
||||
const auto& rpk = dynamic_cast<const PublicKey&>(pk);
|
||||
const int len = i2d_PublicKey(rpk.pkey.get(), nullptr);
|
||||
auto raw = bytes(len);
|
||||
auto* data_ptr = raw.data();
|
||||
if (len != i2d_PublicKey(rpk.pkey.get(), &data_ptr)) {
|
||||
throw openssl_error();
|
||||
}
|
||||
return raw;
|
||||
}
|
||||
|
||||
std::unique_ptr<Signature::PublicKey>
|
||||
RSASignature::deserialize(const bytes& enc) const
|
||||
{
|
||||
const auto* data_ptr = enc.data();
|
||||
auto* pkey = d2i_PublicKey(
|
||||
EVP_PKEY_RSA, nullptr, &data_ptr, static_cast<int>(enc.size()));
|
||||
if (pkey == nullptr) {
|
||||
throw openssl_error();
|
||||
}
|
||||
return std::make_unique<RSASignature::PublicKey>(pkey);
|
||||
}
|
||||
|
||||
bytes
|
||||
RSASignature::serialize_private(const Signature::PrivateKey& sk) const
|
||||
{
|
||||
const auto& rsk = dynamic_cast<const PrivateKey&>(sk);
|
||||
const int len = i2d_PrivateKey(rsk.pkey.get(), nullptr);
|
||||
auto raw = bytes(len);
|
||||
auto* data_ptr = raw.data();
|
||||
if (len != i2d_PrivateKey(rsk.pkey.get(), &data_ptr)) {
|
||||
throw openssl_error();
|
||||
}
|
||||
|
||||
return raw;
|
||||
}
|
||||
|
||||
std::unique_ptr<Signature::PrivateKey>
|
||||
RSASignature::deserialize_private(const bytes& skm) const
|
||||
{
|
||||
const auto* data_ptr = skm.data();
|
||||
auto* pkey = d2i_PrivateKey(
|
||||
EVP_PKEY_RSA, nullptr, &data_ptr, static_cast<int>(skm.size()));
|
||||
if (pkey == nullptr) {
|
||||
throw openssl_error();
|
||||
}
|
||||
return std::make_unique<RSASignature::PrivateKey>(pkey);
|
||||
}
|
||||
|
||||
bytes
|
||||
RSASignature::sign(const bytes& data, const Signature::PrivateKey& sk) const
|
||||
{
|
||||
const auto& rsk = dynamic_cast<const PrivateKey&>(sk);
|
||||
|
||||
auto ctx = make_typed_unique(EVP_MD_CTX_create());
|
||||
if (ctx == nullptr) {
|
||||
throw openssl_error();
|
||||
}
|
||||
|
||||
if (1 !=
|
||||
EVP_DigestSignInit(ctx.get(), nullptr, md, nullptr, rsk.pkey.get())) {
|
||||
throw openssl_error();
|
||||
}
|
||||
|
||||
size_t siglen = EVP_PKEY_size(rsk.pkey.get());
|
||||
bytes sig(siglen);
|
||||
if (1 != EVP_DigestSign(
|
||||
ctx.get(), sig.data(), &siglen, data.data(), data.size())) {
|
||||
throw openssl_error();
|
||||
}
|
||||
|
||||
sig.resize(siglen);
|
||||
return sig;
|
||||
}
|
||||
|
||||
bool
|
||||
RSASignature::verify(const bytes& data,
|
||||
const bytes& sig,
|
||||
const Signature::PublicKey& pk) const
|
||||
{
|
||||
const auto& rpk = dynamic_cast<const PublicKey&>(pk);
|
||||
|
||||
auto ctx = make_typed_unique(EVP_MD_CTX_create());
|
||||
if (ctx == nullptr) {
|
||||
throw openssl_error();
|
||||
}
|
||||
|
||||
if (1 !=
|
||||
EVP_DigestVerifyInit(ctx.get(), nullptr, md, nullptr, rpk.pkey.get())) {
|
||||
throw openssl_error();
|
||||
}
|
||||
|
||||
auto rv = EVP_DigestVerify(
|
||||
ctx.get(), sig.data(), sig.size(), data.data(), data.size());
|
||||
|
||||
return rv == 1;
|
||||
}
|
||||
|
||||
// TODO(RLB) Implement these methods. No concrete need, but might be nice for
|
||||
// completeness.
|
||||
std::unique_ptr<Signature::PrivateKey>
|
||||
RSASignature::import_jwk_private(const std::string& /* json_str */) const
|
||||
{
|
||||
throw std::runtime_error("not implemented");
|
||||
}
|
||||
|
||||
std::unique_ptr<Signature::PublicKey>
|
||||
RSASignature::import_jwk(const std::string& /* json_str */) const
|
||||
{
|
||||
throw std::runtime_error("not implemented");
|
||||
}
|
||||
|
||||
std::string
|
||||
RSASignature::export_jwk_private(const Signature::PrivateKey& /* sk */) const
|
||||
{
|
||||
throw std::runtime_error("not implemented");
|
||||
}
|
||||
|
||||
std::string
|
||||
RSASignature::export_jwk(const Signature::PublicKey& /* pk */) const
|
||||
{
|
||||
throw std::runtime_error("not implemented");
|
||||
}
|
||||
|
||||
const EVP_MD*
|
||||
RSASignature::digest_to_md(Digest::ID digest)
|
||||
{
|
||||
// NOLINTNEXTLINE(hicpp-multiway-paths-covered)
|
||||
switch (digest) {
|
||||
case Digest::ID::SHA256:
|
||||
return EVP_sha256();
|
||||
case Digest::ID::SHA384:
|
||||
return EVP_sha384();
|
||||
case Digest::ID::SHA512:
|
||||
return EVP_sha512();
|
||||
default:
|
||||
throw std::runtime_error("Unsupported digest");
|
||||
}
|
||||
}
|
||||
|
||||
Signature::ID
|
||||
RSASignature::digest_to_sig(Digest::ID digest)
|
||||
{
|
||||
// NOLINTNEXTLINE(hicpp-multiway-paths-covered)
|
||||
switch (digest) {
|
||||
case Digest::ID::SHA256:
|
||||
return Signature::ID::RSA_SHA256;
|
||||
case Digest::ID::SHA384:
|
||||
return Signature::ID::RSA_SHA384;
|
||||
case Digest::ID::SHA512:
|
||||
return Signature::ID::RSA_SHA512;
|
||||
default:
|
||||
throw std::runtime_error("Unsupported digest");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlspp::hpke
|
||||
97
DPP/mlspp/lib/hpke/src/rsa.h
Executable file
97
DPP/mlspp/lib/hpke/src/rsa.h
Executable file
@@ -0,0 +1,97 @@
|
||||
#pragma once
|
||||
|
||||
#include <hpke/digest.h>
|
||||
#include <hpke/hpke.h>
|
||||
#include <hpke/signature.h>
|
||||
|
||||
#include "openssl_common.h"
|
||||
#include <openssl/evp.h>
|
||||
#include <openssl/rsa.h>
|
||||
|
||||
namespace mlspp::hpke {
|
||||
|
||||
// XXX(RLB): There is a lot of code in RSASignature that is duplicated in
|
||||
// EVPGroup. I have allowed this duplication rather than factoring it out
|
||||
// because I would like to be able to cleanly remove RSA later.
|
||||
struct RSASignature : public Signature
|
||||
{
|
||||
struct PublicKey : public Signature::PublicKey
|
||||
{
|
||||
explicit PublicKey(EVP_PKEY* pkey_in)
|
||||
: pkey(pkey_in, typed_delete<EVP_PKEY>)
|
||||
{
|
||||
}
|
||||
|
||||
~PublicKey() override = default;
|
||||
|
||||
typed_unique_ptr<EVP_PKEY> pkey;
|
||||
};
|
||||
|
||||
struct PrivateKey : public Signature::PrivateKey
|
||||
{
|
||||
explicit PrivateKey(EVP_PKEY* pkey_in)
|
||||
: pkey(pkey_in, typed_delete<EVP_PKEY>)
|
||||
{
|
||||
}
|
||||
|
||||
~PrivateKey() override = default;
|
||||
|
||||
std::unique_ptr<Signature::PublicKey> public_key() const override
|
||||
{
|
||||
if (1 != EVP_PKEY_up_ref(pkey.get())) {
|
||||
throw openssl_error();
|
||||
}
|
||||
return std::make_unique<PublicKey>(pkey.get());
|
||||
}
|
||||
|
||||
typed_unique_ptr<EVP_PKEY> pkey;
|
||||
};
|
||||
|
||||
explicit RSASignature(Digest::ID digest)
|
||||
: Signature(digest_to_sig(digest))
|
||||
, md(digest_to_md(digest))
|
||||
{
|
||||
}
|
||||
|
||||
std::unique_ptr<Signature::PrivateKey> generate_key_pair() const override;
|
||||
|
||||
std::unique_ptr<Signature::PrivateKey> derive_key_pair(
|
||||
const bytes& /*ikm*/) const override;
|
||||
|
||||
static std::unique_ptr<Signature::PrivateKey> generate_key_pair(size_t bits);
|
||||
|
||||
// TODO(rlb): Implement derive() with sizes
|
||||
|
||||
bytes serialize(const Signature::PublicKey& pk) const override;
|
||||
|
||||
std::unique_ptr<Signature::PublicKey> deserialize(
|
||||
const bytes& enc) const override;
|
||||
|
||||
bytes serialize_private(const Signature::PrivateKey& sk) const override;
|
||||
|
||||
std::unique_ptr<Signature::PrivateKey> deserialize_private(
|
||||
const bytes& skm) const override;
|
||||
|
||||
bytes sign(const bytes& data, const Signature::PrivateKey& sk) const override;
|
||||
|
||||
bool verify(const bytes& data,
|
||||
const bytes& sig,
|
||||
const Signature::PublicKey& pk) const override;
|
||||
|
||||
std::unique_ptr<Signature::PrivateKey> import_jwk_private(
|
||||
const std::string& json_str) const override;
|
||||
std::unique_ptr<Signature::PublicKey> import_jwk(
|
||||
const std::string& json_str) const override;
|
||||
std::string export_jwk_private(
|
||||
const Signature::PrivateKey& sk) const override;
|
||||
std::string export_jwk(const Signature::PublicKey& pk) const override;
|
||||
|
||||
private:
|
||||
const EVP_MD* md;
|
||||
|
||||
static const EVP_MD* digest_to_md(Digest::ID digest);
|
||||
|
||||
static Signature::ID digest_to_sig(Digest::ID digest);
|
||||
};
|
||||
|
||||
} // namespace mlspp::hpke
|
||||
344
DPP/mlspp/lib/hpke/src/signature.cpp
Executable file
344
DPP/mlspp/lib/hpke/src/signature.cpp
Executable file
@@ -0,0 +1,344 @@
|
||||
#include <hpke/base64.h>
|
||||
#include <hpke/digest.h>
|
||||
#include <hpke/signature.h>
|
||||
#include <string>
|
||||
|
||||
#include "dhkem.h"
|
||||
#include "rsa.h"
|
||||
|
||||
#include <dpp/json.h>
|
||||
#include <openssl/bio.h>
|
||||
#include <openssl/bn.h>
|
||||
#include <openssl/ec.h>
|
||||
#include <openssl/evp.h>
|
||||
#include <openssl/rsa.h>
|
||||
|
||||
using nlohmann::json;
|
||||
|
||||
namespace mlspp::hpke {
|
||||
|
||||
struct GroupSignature : public Signature
|
||||
{
|
||||
struct PrivateKey : public Signature::PrivateKey
|
||||
{
|
||||
explicit PrivateKey(Group::PrivateKey* group_priv_in)
|
||||
: group_priv(group_priv_in)
|
||||
{
|
||||
}
|
||||
|
||||
std::unique_ptr<Signature::PublicKey> public_key() const override
|
||||
{
|
||||
return group_priv->public_key();
|
||||
}
|
||||
|
||||
std::unique_ptr<Group::PrivateKey> group_priv;
|
||||
};
|
||||
|
||||
static Signature::ID group_to_sig(Group::ID group_id)
|
||||
{
|
||||
switch (group_id) {
|
||||
case Group::ID::P256:
|
||||
return Signature::ID::P256_SHA256;
|
||||
case Group::ID::P384:
|
||||
return Signature::ID::P384_SHA384;
|
||||
case Group::ID::P521:
|
||||
return Signature::ID::P521_SHA512;
|
||||
case Group::ID::Ed25519:
|
||||
return Signature::ID::Ed25519;
|
||||
#if !defined(WITH_BORINGSSL)
|
||||
case Group::ID::Ed448:
|
||||
return Signature::ID::Ed448;
|
||||
#endif
|
||||
default:
|
||||
throw std::runtime_error("Unsupported group");
|
||||
}
|
||||
}
|
||||
|
||||
explicit GroupSignature(const Group& group_in)
|
||||
: Signature(group_to_sig(group_in.id))
|
||||
, group(group_in)
|
||||
{
|
||||
}
|
||||
|
||||
std::unique_ptr<Signature::PrivateKey> generate_key_pair() const override
|
||||
{
|
||||
return std::make_unique<PrivateKey>(group.generate_key_pair().release());
|
||||
}
|
||||
|
||||
std::unique_ptr<Signature::PrivateKey> derive_key_pair(
|
||||
const bytes& ikm) const override
|
||||
{
|
||||
return std::make_unique<PrivateKey>(
|
||||
group.derive_key_pair({}, ikm).release());
|
||||
}
|
||||
|
||||
bytes serialize(const Signature::PublicKey& pk) const override
|
||||
{
|
||||
const auto& rpk = dynamic_cast<const Group::PublicKey&>(pk);
|
||||
return group.serialize(rpk);
|
||||
}
|
||||
|
||||
std::unique_ptr<Signature::PublicKey> deserialize(
|
||||
const bytes& enc) const override
|
||||
{
|
||||
return group.deserialize(enc);
|
||||
}
|
||||
|
||||
bytes serialize_private(const Signature::PrivateKey& sk) const override
|
||||
{
|
||||
const auto& rsk = dynamic_cast<const PrivateKey&>(sk);
|
||||
return group.serialize_private(*rsk.group_priv);
|
||||
}
|
||||
|
||||
std::unique_ptr<Signature::PrivateKey> deserialize_private(
|
||||
const bytes& skm) const override
|
||||
{
|
||||
return std::make_unique<PrivateKey>(
|
||||
group.deserialize_private(skm).release());
|
||||
}
|
||||
|
||||
bytes sign(const bytes& data, const Signature::PrivateKey& sk) const override
|
||||
{
|
||||
const auto& rsk = dynamic_cast<const PrivateKey&>(sk);
|
||||
return group.sign(data, *rsk.group_priv);
|
||||
}
|
||||
|
||||
bool verify(const bytes& data,
|
||||
const bytes& sig,
|
||||
const Signature::PublicKey& pk) const override
|
||||
{
|
||||
const auto& rpk = dynamic_cast<const Group::PublicKey&>(pk);
|
||||
return group.verify(data, sig, rpk);
|
||||
}
|
||||
|
||||
std::unique_ptr<Signature::PrivateKey> import_jwk_private(
|
||||
const std::string& jwk_json) const override
|
||||
{
|
||||
const auto jwk = validate_jwk_json(jwk_json, true);
|
||||
|
||||
const auto d = from_base64url(jwk.at("d"));
|
||||
auto gsk = group.deserialize_private(d);
|
||||
|
||||
return std::make_unique<PrivateKey>(gsk.release());
|
||||
}
|
||||
|
||||
std::unique_ptr<Signature::PublicKey> import_jwk(
|
||||
const std::string& jwk_json) const override
|
||||
{
|
||||
const auto jwk = validate_jwk_json(jwk_json, false);
|
||||
|
||||
const auto x = from_base64url(jwk.at("x"));
|
||||
auto y = bytes{};
|
||||
if (jwk.contains("y")) {
|
||||
y = from_base64url(jwk.at("y"));
|
||||
}
|
||||
|
||||
return group.public_key_from_coordinates(x, y);
|
||||
}
|
||||
|
||||
std::string export_jwk(const Signature::PublicKey& pk) const override
|
||||
{
|
||||
const auto& gpk = dynamic_cast<const Group::PublicKey&>(pk);
|
||||
const auto jwk_json = export_jwk_json(gpk);
|
||||
return jwk_json.dump();
|
||||
}
|
||||
|
||||
std::string export_jwk_private(const Signature::PrivateKey& sk) const override
|
||||
{
|
||||
const auto& gssk = dynamic_cast<const GroupSignature::PrivateKey&>(sk);
|
||||
const auto& gsk = gssk.group_priv;
|
||||
const auto gpk = gsk->public_key();
|
||||
|
||||
auto jwk_json = export_jwk_json(*gpk);
|
||||
|
||||
// encode the private key
|
||||
const auto enc = serialize_private(sk);
|
||||
jwk_json.emplace("d", to_base64url(enc));
|
||||
|
||||
return jwk_json.dump();
|
||||
}
|
||||
|
||||
private:
|
||||
const Group& group;
|
||||
|
||||
json validate_jwk_json(const std::string& jwk_json, bool private_key) const
|
||||
{
|
||||
json jwk = json::parse(jwk_json);
|
||||
|
||||
if (jwk.empty() || !jwk.contains("kty") || !jwk.contains("crv") ||
|
||||
!jwk.contains("x") || (private_key && !jwk.contains("d"))) {
|
||||
throw std::runtime_error("malformed JWK");
|
||||
}
|
||||
|
||||
if (jwk.at("kty") != group.jwk_key_type) {
|
||||
throw std::runtime_error("invalid JWK key type");
|
||||
}
|
||||
|
||||
if (jwk.at("crv") != group.jwk_curve_name) {
|
||||
throw std::runtime_error("invalid JWK curve");
|
||||
}
|
||||
|
||||
return jwk;
|
||||
}
|
||||
|
||||
json export_jwk_json(const Group::PublicKey& pk) const
|
||||
{
|
||||
const auto [x, y] = group.coordinates(pk);
|
||||
|
||||
json jwk = json::object({
|
||||
{ "crv", group.jwk_curve_name },
|
||||
{ "kty", group.jwk_key_type },
|
||||
});
|
||||
|
||||
if (group.jwk_key_type == "EC") {
|
||||
jwk.emplace("x", to_base64url(x));
|
||||
jwk.emplace("y", to_base64url(y));
|
||||
} else if (group.jwk_key_type == "OKP") {
|
||||
jwk.emplace("x", to_base64url(x));
|
||||
} else {
|
||||
throw std::runtime_error("unknown key type");
|
||||
}
|
||||
|
||||
return jwk;
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
const Signature&
|
||||
Signature::get<Signature::ID::P256_SHA256>()
|
||||
{
|
||||
static const auto instance = GroupSignature(Group::get<Group::ID::P256>());
|
||||
return instance;
|
||||
}
|
||||
|
||||
template<>
|
||||
const Signature&
|
||||
Signature::get<Signature::ID::P384_SHA384>()
|
||||
{
|
||||
static const auto instance = GroupSignature(Group::get<Group::ID::P384>());
|
||||
return instance;
|
||||
}
|
||||
|
||||
template<>
|
||||
const Signature&
|
||||
Signature::get<Signature::ID::P521_SHA512>()
|
||||
{
|
||||
static const auto instance = GroupSignature(Group::get<Group::ID::P521>());
|
||||
return instance;
|
||||
}
|
||||
|
||||
template<>
|
||||
const Signature&
|
||||
Signature::get<Signature::ID::Ed25519>()
|
||||
{
|
||||
static const auto instance = GroupSignature(Group::get<Group::ID::Ed25519>());
|
||||
return instance;
|
||||
}
|
||||
|
||||
#if !defined(WITH_BORINGSSL)
|
||||
template<>
|
||||
const Signature&
|
||||
Signature::get<Signature::ID::Ed448>()
|
||||
{
|
||||
static const auto instance = GroupSignature(Group::get<Group::ID::Ed448>());
|
||||
return instance;
|
||||
}
|
||||
#endif
|
||||
|
||||
template<>
|
||||
const Signature&
|
||||
Signature::get<Signature::ID::RSA_SHA256>()
|
||||
{
|
||||
static const auto instance = RSASignature(Digest::ID::SHA256);
|
||||
return instance;
|
||||
}
|
||||
|
||||
template<>
|
||||
const Signature&
|
||||
Signature::get<Signature::ID::RSA_SHA384>()
|
||||
{
|
||||
static const auto instance = RSASignature(Digest::ID::SHA384);
|
||||
return instance;
|
||||
}
|
||||
|
||||
template<>
|
||||
const Signature&
|
||||
Signature::get<Signature::ID::RSA_SHA512>()
|
||||
{
|
||||
static const auto instance = RSASignature(Digest::ID::SHA512);
|
||||
return instance;
|
||||
}
|
||||
|
||||
Signature::Signature(Signature::ID id_in)
|
||||
: id(id_in)
|
||||
{
|
||||
}
|
||||
|
||||
std::unique_ptr<Signature::PrivateKey>
|
||||
Signature::generate_rsa(size_t bits)
|
||||
{
|
||||
return RSASignature::generate_key_pair(bits);
|
||||
}
|
||||
|
||||
static const Signature&
|
||||
sig_from_jwk(const std::string& jwk_json)
|
||||
{
|
||||
using KeyTypeAndCurve = std::tuple<std::string, std::string>;
|
||||
static const auto alg_sig_map = std::map<KeyTypeAndCurve, const Signature&>
|
||||
{
|
||||
{ { "EC", "P-256" }, Signature::get<Signature::ID::P256_SHA256>() },
|
||||
{ { "EC", "P-384" }, Signature::get<Signature::ID::P384_SHA384>() },
|
||||
{ { "EC", "P-512" }, Signature::get<Signature::ID::P521_SHA512>() },
|
||||
{ { "OKP", "Ed25519" }, Signature::get<Signature::ID::Ed25519>() },
|
||||
#if !defined(WITH_BORINGSSL)
|
||||
{ { "OKP", "Ed448" }, Signature::get<Signature::ID::Ed448>() },
|
||||
#endif
|
||||
// TODO(RLB): RSA
|
||||
};
|
||||
|
||||
const auto jwk = json::parse(jwk_json);
|
||||
const auto& kty = jwk.at("kty");
|
||||
|
||||
auto crv = std::string("");
|
||||
if (jwk.contains("crv")) {
|
||||
crv = jwk.at("crv");
|
||||
}
|
||||
|
||||
const auto key = KeyTypeAndCurve{ kty, crv };
|
||||
return alg_sig_map.at(key);
|
||||
}
|
||||
|
||||
Signature::PrivateJWK
|
||||
Signature::parse_jwk_private(const std::string& jwk_json)
|
||||
{
|
||||
// XXX(RLB): This JSON-parses the JWK twice. I'm assuming that this is a less
|
||||
// bad cost than changing the import_jwk method signature to take `json`.
|
||||
const auto& sig = sig_from_jwk(jwk_json);
|
||||
const auto jwk = json::parse(jwk_json);
|
||||
auto priv = sig.import_jwk_private(jwk_json);
|
||||
|
||||
auto kid = std::optional<std::string>{};
|
||||
if (jwk.contains("kid")) {
|
||||
kid = jwk.at("kid").get<std::string>();
|
||||
}
|
||||
|
||||
return { sig, kid, std::move(priv) };
|
||||
}
|
||||
|
||||
Signature::PublicJWK
|
||||
Signature::parse_jwk(const std::string& jwk_json)
|
||||
{
|
||||
// XXX(RLB): Same double-parsing comment as with `parse_jwk_private`
|
||||
const auto& sig = sig_from_jwk(jwk_json);
|
||||
const auto jwk = json::parse(jwk_json);
|
||||
auto pub = sig.import_jwk(jwk_json);
|
||||
|
||||
auto kid = std::optional<std::string>{};
|
||||
if (jwk.contains("kid")) {
|
||||
kid = jwk.at("kid").get<std::string>();
|
||||
}
|
||||
|
||||
return { sig, kid, std::move(pub) };
|
||||
}
|
||||
|
||||
} // namespace mlspp::hpke
|
||||
401
DPP/mlspp/lib/hpke/src/userinfo_vc.cpp
Executable file
401
DPP/mlspp/lib/hpke/src/userinfo_vc.cpp
Executable file
@@ -0,0 +1,401 @@
|
||||
#include <hpke/base64.h>
|
||||
#include <hpke/signature.h>
|
||||
#include <hpke/userinfo_vc.h>
|
||||
#include <dpp/json.h>
|
||||
#include <tls/compat.h>
|
||||
|
||||
using nlohmann::json;
|
||||
|
||||
namespace mlspp::hpke {
|
||||
|
||||
static const std::string name_attr = "name";
|
||||
static const std::string sub_attr = "sub";
|
||||
static const std::string given_name_attr = "given_name";
|
||||
static const std::string family_name_attr = "family_name";
|
||||
static const std::string middle_name_attr = "middle_name";
|
||||
static const std::string nickname_attr = "nickname";
|
||||
static const std::string preferred_username_attr = "preferred_username";
|
||||
static const std::string profile_attr = "profile";
|
||||
static const std::string picture_attr = "picture";
|
||||
static const std::string website_attr = "website";
|
||||
static const std::string email_attr = "email";
|
||||
static const std::string email_verified_attr = "email_verified";
|
||||
static const std::string gender_attr = "gender";
|
||||
static const std::string birthdate_attr = "birthdate";
|
||||
static const std::string zoneinfo_attr = "zoneinfo";
|
||||
static const std::string locale_attr = "locale";
|
||||
static const std::string phone_number_attr = "phone_number";
|
||||
static const std::string phone_number_verified_attr = "phone_number_verified";
|
||||
static const std::string address_attr = "address";
|
||||
static const std::string address_formatted_attr = "formatted";
|
||||
static const std::string address_street_address_attr = "street_address";
|
||||
static const std::string address_locality_attr = "locality";
|
||||
static const std::string address_region_attr = "region";
|
||||
static const std::string address_postal_code_attr = "postal_code";
|
||||
static const std::string address_country_attr = "country";
|
||||
static const std::string updated_at_attr = "updated_at";
|
||||
|
||||
template<typename T>
|
||||
static std::optional<T>
|
||||
get_optional(const json& json_object, const std::string& field_name)
|
||||
{
|
||||
if (!json_object.contains(field_name)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
return { json_object.at(field_name).get<T>() };
|
||||
}
|
||||
|
||||
///
|
||||
/// ParsedCredential
|
||||
///
|
||||
static const Signature&
|
||||
signature_from_alg(const std::string& alg)
|
||||
{
|
||||
static const auto alg_sig_map = std::map<std::string, const Signature&>
|
||||
{
|
||||
{ "ES256", Signature::get<Signature::ID::P256_SHA256>() },
|
||||
{ "ES384", Signature::get<Signature::ID::P384_SHA384>() },
|
||||
{ "ES512", Signature::get<Signature::ID::P521_SHA512>() },
|
||||
{ "Ed25519", Signature::get<Signature::ID::Ed25519>() },
|
||||
#if !defined(WITH_BORINGSSL)
|
||||
{ "Ed448", Signature::get<Signature::ID::Ed448>() },
|
||||
#endif
|
||||
{ "RS256", Signature::get<Signature::ID::RSA_SHA256>() },
|
||||
{ "RS384", Signature::get<Signature::ID::RSA_SHA384>() },
|
||||
{ "RS512", Signature::get<Signature::ID::RSA_SHA512>() },
|
||||
};
|
||||
|
||||
return alg_sig_map.at(alg);
|
||||
}
|
||||
|
||||
static std::chrono::system_clock::time_point
|
||||
epoch_time(int64_t seconds_since_epoch)
|
||||
{
|
||||
const auto delta = std::chrono::seconds(seconds_since_epoch);
|
||||
return std::chrono::system_clock::time_point(delta);
|
||||
}
|
||||
|
||||
static bool
|
||||
is_ecdsa(const Signature& sig)
|
||||
{
|
||||
return sig.id == Signature::ID::P256_SHA256 ||
|
||||
sig.id == Signature::ID::P384_SHA384 ||
|
||||
sig.id == Signature::ID::P521_SHA512;
|
||||
}
|
||||
|
||||
// OpenSSL expects ECDSA signatures to be in DER form. JWS provides the
|
||||
// signature in raw R||S form. So we need to do some manual DER encoding.
|
||||
static bytes
|
||||
jws_to_der_sig(const bytes& jws_sig)
|
||||
{
|
||||
// Inputs that are too large will result in invalid DER encodings with this
|
||||
// code. At this size, the combination of the DER integer headers and the
|
||||
// integer data will overflow the one-byte DER struct length.
|
||||
static const auto max_sig_size = size_t(250);
|
||||
if (jws_sig.size() > max_sig_size) {
|
||||
throw std::runtime_error("JWS signature too large");
|
||||
}
|
||||
|
||||
if (jws_sig.size() % 2 != 0) {
|
||||
throw std::runtime_error("Malformed JWS signature");
|
||||
}
|
||||
|
||||
const auto int_size = jws_sig.size() / 2;
|
||||
const auto jws_sig_cut =
|
||||
jws_sig.begin() + static_cast<std::ptrdiff_t>(int_size);
|
||||
|
||||
// Compute the encoded size of R and S integer data, adding a zero byte if
|
||||
// needed to clear the sign bit
|
||||
const auto r_big = (jws_sig.at(0) >= 0x80);
|
||||
const auto s_big = (jws_sig.at(int_size) >= 0x80);
|
||||
|
||||
const auto r_size = int_size + (r_big ? 1 : 0);
|
||||
const auto s_size = int_size + (s_big ? 1 : 0);
|
||||
|
||||
// Compute the size of the DER-encoded signature
|
||||
static const auto int_header_size = 2;
|
||||
const auto r_int_size = int_header_size + r_size;
|
||||
const auto s_int_size = int_header_size + s_size;
|
||||
|
||||
const auto content_size = r_int_size + s_int_size;
|
||||
const auto content_big = (content_size > 0x80);
|
||||
|
||||
auto der_header_size = 2 + (content_big ? 1 : 0);
|
||||
const auto der_size = der_header_size + content_size;
|
||||
|
||||
// Allocate the DER buffer
|
||||
auto der = bytes(der_size, 0);
|
||||
|
||||
// Write the header
|
||||
der.at(0) = 0x30;
|
||||
if (content_big) {
|
||||
der.at(1) = 0x81;
|
||||
der.at(2) = static_cast<uint8_t>(content_size);
|
||||
} else {
|
||||
der.at(1) = static_cast<uint8_t>(content_size);
|
||||
}
|
||||
|
||||
// Write R, virtually padding with a zero byte if needed
|
||||
const auto r_start = der_header_size;
|
||||
const auto r_data_start = r_start + int_header_size + (r_big ? 1 : 0);
|
||||
const auto r_data_begin =
|
||||
der.begin() + static_cast<std::ptrdiff_t>(r_data_start);
|
||||
|
||||
der.at(r_start) = 0x02;
|
||||
der.at(r_start + 1) = static_cast<uint8_t>(r_size);
|
||||
std::copy(jws_sig.begin(), jws_sig_cut, r_data_begin);
|
||||
|
||||
// Write S, virtually padding with a zero byte if needed
|
||||
const auto s_start = der_header_size + r_int_size;
|
||||
const auto s_data_start = s_start + int_header_size + (s_big ? 1 : 0);
|
||||
const auto s_data_begin =
|
||||
der.begin() + static_cast<std::ptrdiff_t>(s_data_start);
|
||||
|
||||
der.at(s_start) = 0x02;
|
||||
der.at(s_start + 1) = static_cast<uint8_t>(s_size);
|
||||
std::copy(jws_sig_cut, jws_sig.end(), s_data_begin);
|
||||
|
||||
return der;
|
||||
}
|
||||
|
||||
struct UserInfoVC::ParsedCredential
|
||||
{
|
||||
// Header fields
|
||||
const Signature& signature_algorithm; // `alg`
|
||||
std::optional<std::string> key_id; // `kid`
|
||||
|
||||
// Top-level Payload fields
|
||||
std::string issuer; // `iss`
|
||||
std::chrono::system_clock::time_point not_before; // `nbf`
|
||||
std::chrono::system_clock::time_point not_after; // `exp`
|
||||
|
||||
// Credential subject fields
|
||||
UserInfoClaims credential_subject;
|
||||
Signature::PublicJWK public_key;
|
||||
|
||||
// Signature verification information
|
||||
bytes to_be_signed;
|
||||
bytes signature;
|
||||
|
||||
ParsedCredential(const Signature& signature_algorithm_in,
|
||||
std::optional<std::string> key_id_in,
|
||||
std::string issuer_in,
|
||||
std::chrono::system_clock::time_point not_before_in,
|
||||
std::chrono::system_clock::time_point not_after_in,
|
||||
UserInfoClaims credential_subject_in,
|
||||
Signature::PublicJWK&& public_key_in,
|
||||
bytes to_be_signed_in,
|
||||
bytes signature_in)
|
||||
: signature_algorithm(signature_algorithm_in)
|
||||
, key_id(std::move(key_id_in))
|
||||
, issuer(std::move(issuer_in))
|
||||
, not_before(not_before_in)
|
||||
, not_after(not_after_in)
|
||||
, credential_subject(std::move(credential_subject_in))
|
||||
, public_key(std::move(public_key_in))
|
||||
, to_be_signed(std::move(to_be_signed_in))
|
||||
, signature(std::move(signature_in))
|
||||
{
|
||||
}
|
||||
|
||||
static std::shared_ptr<ParsedCredential> parse(const std::string& jwt)
|
||||
{
|
||||
// Split the JWT into its header, payload, and signature
|
||||
const auto first_dot = jwt.find_first_of('.');
|
||||
const auto last_dot = jwt.find_last_of('.');
|
||||
if (first_dot == std::string::npos || last_dot == std::string::npos ||
|
||||
first_dot == last_dot || last_dot > jwt.length() - 2) {
|
||||
throw std::runtime_error("malformed JWT; not enough '.' characters");
|
||||
}
|
||||
|
||||
const auto header_b64 = jwt.substr(0, first_dot);
|
||||
const auto payload_b64 =
|
||||
jwt.substr(first_dot + 1, last_dot - first_dot - 1);
|
||||
const auto signature_b64 = jwt.substr(last_dot + 1);
|
||||
|
||||
// Parse the components
|
||||
const auto header = json::parse(to_ascii(from_base64url(header_b64)));
|
||||
const auto payload = json::parse(to_ascii(from_base64url(payload_b64)));
|
||||
|
||||
// Prepare the validation inputs
|
||||
const auto hdr = header.at("alg");
|
||||
const auto& sig = signature_from_alg(hdr);
|
||||
const auto to_be_signed = from_ascii(header_b64 + "." + payload_b64);
|
||||
auto signature = from_base64url(signature_b64);
|
||||
if (is_ecdsa(sig)) {
|
||||
signature = jws_to_der_sig(signature);
|
||||
}
|
||||
|
||||
auto kid = std::optional<std::string>{};
|
||||
if (header.contains("kid")) {
|
||||
kid = header.at("kid").get<std::string>();
|
||||
}
|
||||
|
||||
// Verify the VC parts
|
||||
const auto& vc = payload.at("vc");
|
||||
|
||||
static const auto context =
|
||||
std::vector<std::string>{ { "https://www.w3.org/2018/credentials/v1" } };
|
||||
const auto vc_context = vc.at("@context").get<std::vector<std::string>>();
|
||||
if (vc_context != context) {
|
||||
throw std::runtime_error("malformed VC: incorrect context value");
|
||||
}
|
||||
|
||||
static const auto type = std::vector<std::string>{
|
||||
"VerifiableCredential",
|
||||
"UserInfoCredential",
|
||||
};
|
||||
if (vc.at("type") != type) {
|
||||
throw std::runtime_error("malformed VC: incorrect type value");
|
||||
}
|
||||
|
||||
// Parse the subject public key
|
||||
static const std::string did_jwk_prefix = "did:jwk:";
|
||||
const auto id = vc.at("credentialSubject").at("id").get<std::string>();
|
||||
if (id.find(did_jwk_prefix) != 0) {
|
||||
throw std::runtime_error("malformed UserInfo VC: ID is not did:jwk");
|
||||
}
|
||||
|
||||
const auto jwk = to_ascii(from_base64url(id.substr(did_jwk_prefix.size())));
|
||||
auto public_key = Signature::parse_jwk(jwk);
|
||||
|
||||
// Extract the salient parts
|
||||
return std::make_shared<ParsedCredential>(
|
||||
sig,
|
||||
kid,
|
||||
|
||||
payload.at("iss"),
|
||||
epoch_time(payload.at("nbf").get<int64_t>()),
|
||||
epoch_time(payload.at("exp").get<int64_t>()),
|
||||
|
||||
UserInfoClaims::from_json(vc.at("credentialSubject").dump()),
|
||||
std::move(public_key),
|
||||
|
||||
to_be_signed,
|
||||
signature);
|
||||
}
|
||||
|
||||
bool verify(const Signature::PublicKey& issuer_key)
|
||||
{
|
||||
return signature_algorithm.verify(to_be_signed, signature, issuer_key);
|
||||
}
|
||||
};
|
||||
|
||||
///
|
||||
/// UserInfoClaims
|
||||
///
|
||||
UserInfoClaims
|
||||
UserInfoClaims::from_json(const std::string& cred_subject)
|
||||
{
|
||||
const auto& cred_subject_json = nlohmann::json::parse(cred_subject);
|
||||
|
||||
std::optional<UserInfoClaimsAddress> address_opt = {};
|
||||
|
||||
if (cred_subject_json.contains(address_attr)) {
|
||||
auto address_json = cred_subject_json.at(address_attr);
|
||||
address_opt = {
|
||||
get_optional<std::string>(address_json, address_formatted_attr),
|
||||
get_optional<std::string>(address_json, address_street_address_attr),
|
||||
get_optional<std::string>(address_json, address_locality_attr),
|
||||
get_optional<std::string>(address_json, address_region_attr),
|
||||
get_optional<std::string>(address_json, address_postal_code_attr),
|
||||
get_optional<std::string>(address_json, address_country_attr)
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
get_optional<std::string>(cred_subject_json, sub_attr),
|
||||
get_optional<std::string>(cred_subject_json, name_attr),
|
||||
get_optional<std::string>(cred_subject_json, given_name_attr),
|
||||
get_optional<std::string>(cred_subject_json, family_name_attr),
|
||||
get_optional<std::string>(cred_subject_json, middle_name_attr),
|
||||
get_optional<std::string>(cred_subject_json, nickname_attr),
|
||||
get_optional<std::string>(cred_subject_json, preferred_username_attr),
|
||||
get_optional<std::string>(cred_subject_json, profile_attr),
|
||||
get_optional<std::string>(cred_subject_json, picture_attr),
|
||||
get_optional<std::string>(cred_subject_json, website_attr),
|
||||
get_optional<std::string>(cred_subject_json, email_attr),
|
||||
get_optional<bool>(cred_subject_json, email_verified_attr),
|
||||
get_optional<std::string>(cred_subject_json, gender_attr),
|
||||
get_optional<std::string>(cred_subject_json, birthdate_attr),
|
||||
get_optional<std::string>(cred_subject_json, zoneinfo_attr),
|
||||
get_optional<std::string>(cred_subject_json, locale_attr),
|
||||
get_optional<std::string>(cred_subject_json, phone_number_attr),
|
||||
get_optional<bool>(cred_subject_json, phone_number_verified_attr),
|
||||
address_opt,
|
||||
get_optional<uint64_t>(cred_subject_json, updated_at_attr),
|
||||
};
|
||||
}
|
||||
|
||||
///
|
||||
/// UserInfoVC
|
||||
///
|
||||
|
||||
UserInfoVC::UserInfoVC(std::string jwt)
|
||||
: parsed_cred(ParsedCredential::parse(jwt))
|
||||
, raw(std::move(jwt))
|
||||
{
|
||||
}
|
||||
|
||||
const Signature&
|
||||
UserInfoVC::signature_algorithm() const
|
||||
{
|
||||
return parsed_cred->signature_algorithm;
|
||||
}
|
||||
|
||||
std::string
|
||||
UserInfoVC::issuer() const
|
||||
{
|
||||
return parsed_cred->issuer;
|
||||
}
|
||||
|
||||
std::optional<std::string>
|
||||
UserInfoVC::key_id() const
|
||||
{
|
||||
return parsed_cred->key_id;
|
||||
}
|
||||
|
||||
bool
|
||||
UserInfoVC::valid_from(const Signature::PublicKey& issuer_key) const
|
||||
{
|
||||
return parsed_cred->verify(issuer_key);
|
||||
}
|
||||
|
||||
const std::string&
|
||||
UserInfoVC::raw_credential() const
|
||||
{
|
||||
return raw;
|
||||
}
|
||||
|
||||
const UserInfoClaims&
|
||||
UserInfoVC::subject() const
|
||||
{
|
||||
return parsed_cred->credential_subject;
|
||||
}
|
||||
|
||||
std::chrono::system_clock::time_point
|
||||
UserInfoVC::not_before() const
|
||||
{
|
||||
return parsed_cred->not_before;
|
||||
}
|
||||
|
||||
std::chrono::system_clock::time_point
|
||||
UserInfoVC::not_after() const
|
||||
{
|
||||
return parsed_cred->not_after;
|
||||
}
|
||||
|
||||
const Signature::PublicJWK&
|
||||
UserInfoVC::public_key() const
|
||||
{
|
||||
return parsed_cred->public_key;
|
||||
}
|
||||
|
||||
bool
|
||||
operator==(const UserInfoVC& lhs, const UserInfoVC& rhs)
|
||||
{
|
||||
return lhs.raw == rhs.raw;
|
||||
}
|
||||
|
||||
} // namespace mlspp::hpke
|
||||
24
DPP/mlspp/lib/mls_vectors/CMakeLists.txt
Executable file
24
DPP/mlspp/lib/mls_vectors/CMakeLists.txt
Executable file
@@ -0,0 +1,24 @@
|
||||
set(CURRENT_LIB_NAME mls_vectors)
|
||||
|
||||
###
|
||||
### Library Config
|
||||
###
|
||||
|
||||
file(GLOB_RECURSE LIB_HEADERS CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/include/*.h")
|
||||
file(GLOB_RECURSE LIB_SOURCES CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp")
|
||||
|
||||
add_library(${CURRENT_LIB_NAME} STATIC ${LIB_HEADERS} ${LIB_SOURCES})
|
||||
add_dependencies(${CURRENT_LIB_NAME} mlspp)
|
||||
target_link_libraries(${CURRENT_LIB_NAME} mlspp bytes tls_syntax)
|
||||
target_include_directories(${CURRENT_LIB_NAME}
|
||||
PUBLIC
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
|
||||
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/include>
|
||||
$<INSTALL_INTERFACE:include/${PROJECT_NAME}>
|
||||
)
|
||||
|
||||
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/bytes/include")
|
||||
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/hpke/include")
|
||||
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/mls_vectors/include")
|
||||
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/tls_syntax/include")
|
||||
|
||||
577
DPP/mlspp/lib/mls_vectors/include/mls_vectors/mls_vectors.h
Executable file
577
DPP/mlspp/lib/mls_vectors/include/mls_vectors/mls_vectors.h
Executable file
@@ -0,0 +1,577 @@
|
||||
#pragma once
|
||||
|
||||
#include <bytes/bytes.h>
|
||||
#include <mls/crypto.h>
|
||||
#include <mls/key_schedule.h>
|
||||
#include <mls/messages.h>
|
||||
#include <mls/tree_math.h>
|
||||
#include <mls/treekem.h>
|
||||
#include <tls/tls_syntax.h>
|
||||
#include <vector>
|
||||
|
||||
namespace mls_vectors {
|
||||
|
||||
struct PseudoRandom
|
||||
{
|
||||
struct Generator
|
||||
{
|
||||
Generator() = default;
|
||||
Generator(mlspp::CipherSuite suite_in, const std::string& label);
|
||||
Generator sub(const std::string& label) const;
|
||||
|
||||
bytes secret(const std::string& label) const;
|
||||
bytes generate(const std::string& label, size_t size) const;
|
||||
|
||||
uint16_t uint16(const std::string& label) const;
|
||||
uint32_t uint32(const std::string& label) const;
|
||||
uint64_t uint64(const std::string& label) const;
|
||||
|
||||
mlspp::SignaturePrivateKey signature_key(
|
||||
const std::string& label) const;
|
||||
mlspp::HPKEPrivateKey hpke_key(const std::string& label) const;
|
||||
|
||||
size_t output_length() const;
|
||||
|
||||
private:
|
||||
mlspp::CipherSuite suite;
|
||||
bytes seed;
|
||||
|
||||
Generator(mlspp::CipherSuite suite_in, bytes seed_in);
|
||||
};
|
||||
|
||||
PseudoRandom() = default;
|
||||
PseudoRandom(mlspp::CipherSuite suite, const std::string& label);
|
||||
|
||||
Generator prg;
|
||||
};
|
||||
|
||||
struct TreeMathTestVector
|
||||
{
|
||||
using OptionalNode = std::optional<mlspp::NodeIndex>;
|
||||
|
||||
mlspp::LeafCount n_leaves;
|
||||
mlspp::NodeCount n_nodes;
|
||||
mlspp::NodeIndex root;
|
||||
std::vector<OptionalNode> left;
|
||||
std::vector<OptionalNode> right;
|
||||
std::vector<OptionalNode> parent;
|
||||
std::vector<OptionalNode> sibling;
|
||||
|
||||
std::optional<mlspp::NodeIndex> null_if_invalid(
|
||||
mlspp::NodeIndex input,
|
||||
mlspp::NodeIndex answer) const;
|
||||
|
||||
TreeMathTestVector() = default;
|
||||
TreeMathTestVector(uint32_t n_leaves);
|
||||
std::optional<std::string> verify() const;
|
||||
};
|
||||
|
||||
struct CryptoBasicsTestVector : PseudoRandom
|
||||
{
|
||||
struct RefHash
|
||||
{
|
||||
std::string label;
|
||||
bytes value;
|
||||
bytes out;
|
||||
|
||||
RefHash() = default;
|
||||
RefHash(mlspp::CipherSuite suite,
|
||||
const PseudoRandom::Generator& prg);
|
||||
std::optional<std::string> verify(mlspp::CipherSuite suite) const;
|
||||
};
|
||||
|
||||
struct ExpandWithLabel
|
||||
{
|
||||
bytes secret;
|
||||
std::string label;
|
||||
bytes context;
|
||||
uint16_t length;
|
||||
bytes out;
|
||||
|
||||
ExpandWithLabel() = default;
|
||||
ExpandWithLabel(mlspp::CipherSuite suite,
|
||||
const PseudoRandom::Generator& prg);
|
||||
std::optional<std::string> verify(mlspp::CipherSuite suite) const;
|
||||
};
|
||||
|
||||
struct DeriveSecret
|
||||
{
|
||||
bytes secret;
|
||||
std::string label;
|
||||
bytes out;
|
||||
|
||||
DeriveSecret() = default;
|
||||
DeriveSecret(mlspp::CipherSuite suite,
|
||||
const PseudoRandom::Generator& prg);
|
||||
std::optional<std::string> verify(mlspp::CipherSuite suite) const;
|
||||
};
|
||||
|
||||
struct DeriveTreeSecret
|
||||
{
|
||||
bytes secret;
|
||||
std::string label;
|
||||
uint32_t generation;
|
||||
uint16_t length;
|
||||
bytes out;
|
||||
|
||||
DeriveTreeSecret() = default;
|
||||
DeriveTreeSecret(mlspp::CipherSuite suite,
|
||||
const PseudoRandom::Generator& prg);
|
||||
std::optional<std::string> verify(mlspp::CipherSuite suite) const;
|
||||
};
|
||||
|
||||
struct SignWithLabel
|
||||
{
|
||||
mlspp::SignaturePrivateKey priv;
|
||||
mlspp::SignaturePublicKey pub;
|
||||
bytes content;
|
||||
std::string label;
|
||||
bytes signature;
|
||||
|
||||
SignWithLabel() = default;
|
||||
SignWithLabel(mlspp::CipherSuite suite,
|
||||
const PseudoRandom::Generator& prg);
|
||||
std::optional<std::string> verify(mlspp::CipherSuite suite) const;
|
||||
};
|
||||
|
||||
struct EncryptWithLabel
|
||||
{
|
||||
mlspp::HPKEPrivateKey priv;
|
||||
mlspp::HPKEPublicKey pub;
|
||||
std::string label;
|
||||
bytes context;
|
||||
bytes plaintext;
|
||||
bytes kem_output;
|
||||
bytes ciphertext;
|
||||
|
||||
EncryptWithLabel() = default;
|
||||
EncryptWithLabel(mlspp::CipherSuite suite,
|
||||
const PseudoRandom::Generator& prg);
|
||||
std::optional<std::string> verify(mlspp::CipherSuite suite) const;
|
||||
};
|
||||
|
||||
mlspp::CipherSuite cipher_suite;
|
||||
|
||||
RefHash ref_hash;
|
||||
ExpandWithLabel expand_with_label;
|
||||
DeriveSecret derive_secret;
|
||||
DeriveTreeSecret derive_tree_secret;
|
||||
SignWithLabel sign_with_label;
|
||||
EncryptWithLabel encrypt_with_label;
|
||||
|
||||
CryptoBasicsTestVector() = default;
|
||||
CryptoBasicsTestVector(mlspp::CipherSuite suite);
|
||||
std::optional<std::string> verify() const;
|
||||
};
|
||||
|
||||
struct SecretTreeTestVector : PseudoRandom
|
||||
{
|
||||
struct SenderData
|
||||
{
|
||||
bytes sender_data_secret;
|
||||
bytes ciphertext;
|
||||
bytes key;
|
||||
bytes nonce;
|
||||
|
||||
SenderData() = default;
|
||||
SenderData(mlspp::CipherSuite suite,
|
||||
const PseudoRandom::Generator& prg);
|
||||
std::optional<std::string> verify(mlspp::CipherSuite suite) const;
|
||||
};
|
||||
|
||||
struct RatchetStep
|
||||
{
|
||||
uint32_t generation;
|
||||
bytes handshake_key;
|
||||
bytes handshake_nonce;
|
||||
bytes application_key;
|
||||
bytes application_nonce;
|
||||
};
|
||||
|
||||
mlspp::CipherSuite cipher_suite;
|
||||
|
||||
SenderData sender_data;
|
||||
|
||||
bytes encryption_secret;
|
||||
std::vector<std::vector<RatchetStep>> leaves;
|
||||
|
||||
SecretTreeTestVector() = default;
|
||||
SecretTreeTestVector(mlspp::CipherSuite suite,
|
||||
uint32_t n_leaves,
|
||||
const std::vector<uint32_t>& generations);
|
||||
std::optional<std::string> verify() const;
|
||||
};
|
||||
|
||||
struct KeyScheduleTestVector : PseudoRandom
|
||||
{
|
||||
struct Export
|
||||
{
|
||||
std::string label;
|
||||
bytes context;
|
||||
size_t length;
|
||||
bytes secret;
|
||||
};
|
||||
|
||||
struct Epoch
|
||||
{
|
||||
// Chosen by the generator
|
||||
bytes tree_hash;
|
||||
bytes commit_secret;
|
||||
bytes psk_secret;
|
||||
bytes confirmed_transcript_hash;
|
||||
|
||||
// Computed values
|
||||
bytes group_context;
|
||||
|
||||
bytes joiner_secret;
|
||||
bytes welcome_secret;
|
||||
bytes init_secret;
|
||||
|
||||
bytes sender_data_secret;
|
||||
bytes encryption_secret;
|
||||
bytes exporter_secret;
|
||||
bytes epoch_authenticator;
|
||||
bytes external_secret;
|
||||
bytes confirmation_key;
|
||||
bytes membership_key;
|
||||
bytes resumption_psk;
|
||||
|
||||
mlspp::HPKEPublicKey external_pub;
|
||||
Export exporter;
|
||||
};
|
||||
|
||||
mlspp::CipherSuite cipher_suite;
|
||||
|
||||
bytes group_id;
|
||||
bytes initial_init_secret;
|
||||
|
||||
std::vector<Epoch> epochs;
|
||||
|
||||
KeyScheduleTestVector() = default;
|
||||
KeyScheduleTestVector(mlspp::CipherSuite suite, uint32_t n_epochs);
|
||||
std::optional<std::string> verify() const;
|
||||
};
|
||||
|
||||
struct MessageProtectionTestVector : PseudoRandom
|
||||
{
|
||||
mlspp::CipherSuite cipher_suite;
|
||||
|
||||
bytes group_id;
|
||||
mlspp::epoch_t epoch;
|
||||
bytes tree_hash;
|
||||
bytes confirmed_transcript_hash;
|
||||
|
||||
mlspp::SignaturePrivateKey signature_priv;
|
||||
mlspp::SignaturePublicKey signature_pub;
|
||||
|
||||
bytes encryption_secret;
|
||||
bytes sender_data_secret;
|
||||
bytes membership_key;
|
||||
|
||||
mlspp::Proposal proposal;
|
||||
mlspp::MLSMessage proposal_pub;
|
||||
mlspp::MLSMessage proposal_priv;
|
||||
|
||||
mlspp::Commit commit;
|
||||
mlspp::MLSMessage commit_pub;
|
||||
mlspp::MLSMessage commit_priv;
|
||||
|
||||
bytes application;
|
||||
mlspp::MLSMessage application_priv;
|
||||
|
||||
MessageProtectionTestVector() = default;
|
||||
MessageProtectionTestVector(mlspp::CipherSuite suite);
|
||||
std::optional<std::string> verify();
|
||||
|
||||
private:
|
||||
mlspp::GroupKeySource group_keys() const;
|
||||
mlspp::GroupContext group_context() const;
|
||||
|
||||
mlspp::MLSMessage protect_pub(
|
||||
const mlspp::GroupContent::RawContent& raw_content) const;
|
||||
mlspp::MLSMessage protect_priv(
|
||||
const mlspp::GroupContent::RawContent& raw_content);
|
||||
std::optional<mlspp::GroupContent> unprotect(
|
||||
const mlspp::MLSMessage& message);
|
||||
};
|
||||
|
||||
struct PSKSecretTestVector : PseudoRandom
|
||||
{
|
||||
struct PSK
|
||||
{
|
||||
bytes psk_id;
|
||||
bytes psk_nonce;
|
||||
bytes psk;
|
||||
};
|
||||
|
||||
mlspp::CipherSuite cipher_suite;
|
||||
std::vector<PSK> psks;
|
||||
bytes psk_secret;
|
||||
|
||||
PSKSecretTestVector() = default;
|
||||
PSKSecretTestVector(mlspp::CipherSuite suite, size_t n_psks);
|
||||
std::optional<std::string> verify() const;
|
||||
};
|
||||
|
||||
struct TranscriptTestVector : PseudoRandom
|
||||
{
|
||||
mlspp::CipherSuite cipher_suite;
|
||||
|
||||
bytes confirmation_key;
|
||||
bytes interim_transcript_hash_before;
|
||||
|
||||
mlspp::AuthenticatedContent authenticated_content;
|
||||
|
||||
bytes confirmed_transcript_hash_after;
|
||||
bytes interim_transcript_hash_after;
|
||||
|
||||
TranscriptTestVector() = default;
|
||||
TranscriptTestVector(mlspp::CipherSuite suite);
|
||||
std::optional<std::string> verify() const;
|
||||
};
|
||||
|
||||
struct WelcomeTestVector : PseudoRandom
|
||||
{
|
||||
mlspp::CipherSuite cipher_suite;
|
||||
|
||||
mlspp::HPKEPrivateKey init_priv;
|
||||
mlspp::SignaturePublicKey signer_pub;
|
||||
|
||||
mlspp::MLSMessage key_package;
|
||||
mlspp::MLSMessage welcome;
|
||||
|
||||
WelcomeTestVector() = default;
|
||||
WelcomeTestVector(mlspp::CipherSuite suite);
|
||||
std::optional<std::string> verify() const;
|
||||
};
|
||||
|
||||
// XXX(RLB): The |structure| of the example trees below is to avoid compile
|
||||
// errors from gcc's -Werror=comment when there is a '\' character at the end of
|
||||
// a line. Inspired by a similar bug in Chromium:
|
||||
// https://codereview.chromium.org/874663003/patch/1/10001
|
||||
enum struct TreeStructure
|
||||
{
|
||||
// Full trees on N leaves, created by member k adding member k+1
|
||||
full_tree_2,
|
||||
full_tree_3,
|
||||
full_tree_4,
|
||||
full_tree_5,
|
||||
full_tree_6,
|
||||
full_tree_7,
|
||||
full_tree_8,
|
||||
full_tree_32,
|
||||
full_tree_33,
|
||||
full_tree_34,
|
||||
|
||||
// | W |
|
||||
// | ______|______ |
|
||||
// | / \ |
|
||||
// | U Y |
|
||||
// | __|__ __|__ |
|
||||
// | / \ / \ |
|
||||
// | T _ X Z |
|
||||
// | / \ / \ / \ / \ |
|
||||
// | A B C _ E F G H |
|
||||
//
|
||||
// * Start with full tree on 8 members
|
||||
// * 0 commits removeing 2 and 3, and adding a new member
|
||||
internal_blanks_no_skipping,
|
||||
|
||||
// | W |
|
||||
// | ______|______ |
|
||||
// | / \ |
|
||||
// | _ Y |
|
||||
// | __|__ __|__ |
|
||||
// | / \ / \ |
|
||||
// | _ _ X Z |
|
||||
// | / \ / \ / \ / \ |
|
||||
// | A _ _ _ E F G H |
|
||||
//
|
||||
// * Start with full tree on 8 members
|
||||
// * 0 commitsremoveing 1, 2, and 3
|
||||
internal_blanks_with_skipping,
|
||||
|
||||
// | W[H] |
|
||||
// | ______|______ |
|
||||
// | / \ |
|
||||
// | U Y[H] |
|
||||
// | __|__ __|__ |
|
||||
// | / \ / \ |
|
||||
// | T V X _ |
|
||||
// | / \ / \ / \ / \ |
|
||||
// | A B C D E F G H |
|
||||
//
|
||||
// * Start with full tree on 7 members
|
||||
// * 0 commits adding a member in a partial Commit (no path)
|
||||
unmerged_leaves_no_skipping,
|
||||
|
||||
// | W [F] |
|
||||
// | ______|______ |
|
||||
// | / \ |
|
||||
// | U Y [F] |
|
||||
// | __|__ __|__ |
|
||||
// | / \ / \ |
|
||||
// | T _ _ _ |
|
||||
// | / \ / \ / \ / \ |
|
||||
// | A B C D E F G _ |
|
||||
//
|
||||
// == Fig. 20 / {{parent-hash-tree}}
|
||||
// * 0 creates group
|
||||
// * 0 adds 1, ..., 6 in a partial Commit
|
||||
// * O commits removing 5
|
||||
// * 4 commits without any proposals
|
||||
// * 0 commits adding a new member in a partial Commit
|
||||
unmerged_leaves_with_skipping,
|
||||
};
|
||||
|
||||
extern std::array<TreeStructure, 14> all_tree_structures;
|
||||
extern std::array<TreeStructure, 11> treekem_test_tree_structures;
|
||||
|
||||
struct TreeHashTestVector : PseudoRandom
|
||||
{
|
||||
mlspp::CipherSuite cipher_suite;
|
||||
bytes group_id;
|
||||
|
||||
mlspp::TreeKEMPublicKey tree;
|
||||
std::vector<bytes> tree_hashes;
|
||||
std::vector<std::vector<mlspp::NodeIndex>> resolutions;
|
||||
|
||||
TreeHashTestVector() = default;
|
||||
TreeHashTestVector(mlspp::CipherSuite suite,
|
||||
TreeStructure tree_structure);
|
||||
std::optional<std::string> verify();
|
||||
};
|
||||
|
||||
struct TreeOperationsTestVector : PseudoRandom
|
||||
{
|
||||
enum struct Scenario
|
||||
{
|
||||
add_right_edge,
|
||||
add_internal,
|
||||
update,
|
||||
remove_right_edge,
|
||||
remove_internal,
|
||||
};
|
||||
|
||||
static const std::vector<Scenario> all_scenarios;
|
||||
|
||||
mlspp::CipherSuite cipher_suite;
|
||||
|
||||
mlspp::TreeKEMPublicKey tree_before;
|
||||
bytes tree_hash_before;
|
||||
|
||||
mlspp::Proposal proposal;
|
||||
mlspp::LeafIndex proposal_sender;
|
||||
|
||||
mlspp::TreeKEMPublicKey tree_after;
|
||||
bytes tree_hash_after;
|
||||
|
||||
TreeOperationsTestVector() = default;
|
||||
TreeOperationsTestVector(mlspp::CipherSuite suite, Scenario scenario);
|
||||
std::optional<std::string> verify();
|
||||
};
|
||||
|
||||
struct TreeKEMTestVector : PseudoRandom
|
||||
{
|
||||
struct PathSecret
|
||||
{
|
||||
mlspp::NodeIndex node;
|
||||
bytes path_secret;
|
||||
};
|
||||
|
||||
struct LeafPrivateInfo
|
||||
{
|
||||
mlspp::LeafIndex index;
|
||||
mlspp::HPKEPrivateKey encryption_priv;
|
||||
mlspp::SignaturePrivateKey signature_priv;
|
||||
std::vector<PathSecret> path_secrets;
|
||||
};
|
||||
|
||||
struct UpdatePathInfo
|
||||
{
|
||||
mlspp::LeafIndex sender;
|
||||
mlspp::UpdatePath update_path;
|
||||
std::vector<std::optional<bytes>> path_secrets;
|
||||
bytes commit_secret;
|
||||
bytes tree_hash_after;
|
||||
};
|
||||
|
||||
mlspp::CipherSuite cipher_suite;
|
||||
|
||||
bytes group_id;
|
||||
mlspp::epoch_t epoch;
|
||||
bytes confirmed_transcript_hash;
|
||||
|
||||
mlspp::TreeKEMPublicKey ratchet_tree;
|
||||
|
||||
std::vector<LeafPrivateInfo> leaves_private;
|
||||
std::vector<UpdatePathInfo> update_paths;
|
||||
|
||||
TreeKEMTestVector() = default;
|
||||
TreeKEMTestVector(mlspp::CipherSuite suite,
|
||||
TreeStructure tree_structure);
|
||||
std::optional<std::string> verify();
|
||||
};
|
||||
|
||||
struct MessagesTestVector : PseudoRandom
|
||||
{
|
||||
bytes mls_welcome;
|
||||
bytes mls_group_info;
|
||||
bytes mls_key_package;
|
||||
|
||||
bytes ratchet_tree;
|
||||
bytes group_secrets;
|
||||
|
||||
bytes add_proposal;
|
||||
bytes update_proposal;
|
||||
bytes remove_proposal;
|
||||
bytes pre_shared_key_proposal;
|
||||
bytes re_init_proposal;
|
||||
bytes external_init_proposal;
|
||||
bytes group_context_extensions_proposal;
|
||||
|
||||
bytes commit;
|
||||
|
||||
bytes public_message_proposal;
|
||||
bytes public_message_commit;
|
||||
bytes private_message;
|
||||
|
||||
MessagesTestVector();
|
||||
std::optional<std::string> verify() const;
|
||||
};
|
||||
|
||||
struct PassiveClientTestVector : PseudoRandom
|
||||
{
|
||||
struct PSK
|
||||
{
|
||||
bytes psk_id;
|
||||
bytes psk;
|
||||
};
|
||||
|
||||
struct Epoch
|
||||
{
|
||||
std::vector<mlspp::MLSMessage> proposals;
|
||||
mlspp::MLSMessage commit;
|
||||
bytes epoch_authenticator;
|
||||
};
|
||||
|
||||
mlspp::CipherSuite cipher_suite;
|
||||
|
||||
mlspp::MLSMessage key_package;
|
||||
mlspp::SignaturePrivateKey signature_priv;
|
||||
mlspp::HPKEPrivateKey encryption_priv;
|
||||
mlspp::HPKEPrivateKey init_priv;
|
||||
|
||||
std::vector<PSK> external_psks;
|
||||
|
||||
mlspp::MLSMessage welcome;
|
||||
std::optional<mlspp::TreeKEMPublicKey> ratchet_tree;
|
||||
bytes initial_epoch_authenticator;
|
||||
|
||||
std::vector<Epoch> epochs;
|
||||
|
||||
PassiveClientTestVector() = default;
|
||||
std::optional<std::string> verify();
|
||||
};
|
||||
|
||||
} // namespace mls_vectors
|
||||
2052
DPP/mlspp/lib/mls_vectors/src/mls_vectors.cpp
Executable file
2052
DPP/mlspp/lib/mls_vectors/src/mls_vectors.cpp
Executable file
File diff suppressed because it is too large
Load Diff
21
DPP/mlspp/lib/tls_syntax/CMakeLists.txt
Executable file
21
DPP/mlspp/lib/tls_syntax/CMakeLists.txt
Executable file
@@ -0,0 +1,21 @@
|
||||
set(CURRENT_LIB_NAME tls_syntax)
|
||||
|
||||
###
|
||||
### Library Config
|
||||
###
|
||||
|
||||
file(GLOB_RECURSE LIB_HEADERS CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/include/*.h")
|
||||
file(GLOB_RECURSE LIB_SOURCES CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp")
|
||||
|
||||
add_library(${CURRENT_LIB_NAME} STATIC ${LIB_HEADERS} ${LIB_SOURCES})
|
||||
target_include_directories(${CURRENT_LIB_NAME}
|
||||
PUBLIC
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
|
||||
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/include>
|
||||
$<INSTALL_INTERFACE:include/${PROJECT_NAME}>
|
||||
)
|
||||
|
||||
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/bytes/include")
|
||||
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/hpke/include")
|
||||
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/mls_vectors/include")
|
||||
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../lib/tls_syntax/include")
|
||||
61
DPP/mlspp/lib/tls_syntax/include/tls/compat.h
Executable file
61
DPP/mlspp/lib/tls_syntax/include/tls/compat.h
Executable file
@@ -0,0 +1,61 @@
|
||||
#pragma once
|
||||
|
||||
#include <optional>
|
||||
#include <stdexcept>
|
||||
|
||||
#ifdef VARIANT_COMPAT
|
||||
#include <variant.hpp>
|
||||
#else
|
||||
#include <variant>
|
||||
#endif // VARIANT_COMPAT
|
||||
|
||||
namespace mlspp::tls {
|
||||
|
||||
namespace var = std;
|
||||
|
||||
// In a similar vein, we provide our own safe accessors for std::optional, since
|
||||
// std::optional::value() is not available on macOS 10.11.
|
||||
namespace opt {
|
||||
|
||||
template<typename T>
|
||||
T&
|
||||
get(std::optional<T>& opt)
|
||||
{
|
||||
if (!opt) {
|
||||
throw std::runtime_error("bad_optional_access");
|
||||
}
|
||||
return *opt;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
const T&
|
||||
get(const std::optional<T>& opt)
|
||||
{
|
||||
if (!opt) {
|
||||
throw std::runtime_error("bad_optional_access");
|
||||
}
|
||||
return *opt;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
T&&
|
||||
get(std::optional<T>&& opt)
|
||||
{
|
||||
if (!opt) {
|
||||
throw std::runtime_error("bad_optional_access");
|
||||
}
|
||||
return std::move(*opt);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
const T&&
|
||||
get(const std::optional<T>&& opt)
|
||||
{
|
||||
if (!opt) {
|
||||
throw std::runtime_error("bad_optional_access");
|
||||
}
|
||||
return std::move(*opt);
|
||||
}
|
||||
|
||||
} // namespace opt
|
||||
} // namespace mlspp::tls
|
||||
569
DPP/mlspp/lib/tls_syntax/include/tls/tls_syntax.h
Executable file
569
DPP/mlspp/lib/tls_syntax/include/tls/tls_syntax.h
Executable file
@@ -0,0 +1,569 @@
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <optional>
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
|
||||
#include <tls/compat.h>
|
||||
|
||||
namespace mlspp::tls {
|
||||
|
||||
// For indicating no min or max in vector definitions
|
||||
const size_t none = std::numeric_limits<size_t>::max();
|
||||
|
||||
class WriteError : public std::invalid_argument
|
||||
{
|
||||
public:
|
||||
using parent = std::invalid_argument;
|
||||
using parent::parent;
|
||||
};
|
||||
|
||||
class ReadError : public std::invalid_argument
|
||||
{
|
||||
public:
|
||||
using parent = std::invalid_argument;
|
||||
using parent::parent;
|
||||
};
|
||||
|
||||
///
|
||||
/// Declarations of Streams and Traits
|
||||
///
|
||||
|
||||
class ostream
|
||||
{
|
||||
public:
|
||||
static const size_t none = std::numeric_limits<size_t>::max();
|
||||
|
||||
void write_raw(const std::vector<uint8_t>& bytes);
|
||||
|
||||
const std::vector<uint8_t>& bytes() const { return _buffer; }
|
||||
size_t size() const { return _buffer.size(); }
|
||||
bool empty() const { return _buffer.empty(); }
|
||||
|
||||
private:
|
||||
std::vector<uint8_t> _buffer;
|
||||
ostream& write_uint(uint64_t value, int length);
|
||||
|
||||
friend ostream& operator<<(ostream& out, bool data);
|
||||
friend ostream& operator<<(ostream& out, uint8_t data);
|
||||
friend ostream& operator<<(ostream& out, uint16_t data);
|
||||
friend ostream& operator<<(ostream& out, uint32_t data);
|
||||
friend ostream& operator<<(ostream& out, uint64_t data);
|
||||
|
||||
template<typename T>
|
||||
friend ostream& operator<<(ostream& out, const std::vector<T>& data);
|
||||
|
||||
friend struct varint;
|
||||
};
|
||||
|
||||
class istream
|
||||
{
|
||||
public:
|
||||
istream(const std::vector<uint8_t>& data)
|
||||
: _buffer(data)
|
||||
{
|
||||
// So that we can use the constant-time pop_back
|
||||
std::reverse(_buffer.begin(), _buffer.end());
|
||||
}
|
||||
|
||||
size_t size() const { return _buffer.size(); }
|
||||
bool empty() const { return _buffer.empty(); }
|
||||
|
||||
std::vector<uint8_t> bytes()
|
||||
{
|
||||
auto bytes = _buffer;
|
||||
std::reverse(bytes.begin(), bytes.end());
|
||||
return bytes;
|
||||
}
|
||||
|
||||
private:
|
||||
istream() {}
|
||||
std::vector<uint8_t> _buffer;
|
||||
uint8_t next();
|
||||
|
||||
template<typename T>
|
||||
istream& read_uint(T& data, size_t length)
|
||||
{
|
||||
uint64_t value = 0;
|
||||
for (size_t i = 0; i < length; i += 1) {
|
||||
value = (value << unsigned(8)) + next();
|
||||
}
|
||||
data = static_cast<T>(value);
|
||||
return *this;
|
||||
}
|
||||
|
||||
friend istream& operator>>(istream& in, bool& data);
|
||||
friend istream& operator>>(istream& in, uint8_t& data);
|
||||
friend istream& operator>>(istream& in, uint16_t& data);
|
||||
friend istream& operator>>(istream& in, uint32_t& data);
|
||||
friend istream& operator>>(istream& in, uint64_t& data);
|
||||
|
||||
template<typename T>
|
||||
friend istream& operator>>(istream& in, std::vector<T>& data);
|
||||
|
||||
friend struct varint;
|
||||
};
|
||||
|
||||
// Traits must have static encode and decode methods, of the following form:
|
||||
//
|
||||
// static ostream& encode(ostream& str, const T& val);
|
||||
// static istream& decode(istream& str, T& val);
|
||||
//
|
||||
// Trait types will never be constructed; only these static methods are used.
|
||||
// The value arguments to encode and decode can be as strict or as loose as
|
||||
// desired.
|
||||
//
|
||||
// Ultimately, all interesting encoding should be done through traits.
|
||||
//
|
||||
// * vectors
|
||||
// * variants
|
||||
// * varints
|
||||
|
||||
struct pass
|
||||
{
|
||||
template<typename T>
|
||||
static ostream& encode(ostream& str, const T& val);
|
||||
|
||||
template<typename T>
|
||||
static istream& decode(istream& str, T& val);
|
||||
};
|
||||
|
||||
template<typename Ts>
|
||||
struct variant
|
||||
{
|
||||
template<typename... Tp>
|
||||
static inline Ts type(const var::variant<Tp...>& data);
|
||||
|
||||
template<typename... Tp>
|
||||
static ostream& encode(ostream& str, const var::variant<Tp...>& data);
|
||||
|
||||
template<size_t I = 0, typename Te, typename... Tp>
|
||||
static inline typename std::enable_if<I == sizeof...(Tp), void>::type
|
||||
read_variant(istream&, Te, var::variant<Tp...>&);
|
||||
|
||||
template<size_t I = 0, typename Te, typename... Tp>
|
||||
static inline typename std::enable_if <
|
||||
I<sizeof...(Tp), void>::type read_variant(istream& str,
|
||||
Te target_type,
|
||||
var::variant<Tp...>& v);
|
||||
|
||||
template<typename... Tp>
|
||||
static istream& decode(istream& str, var::variant<Tp...>& data);
|
||||
};
|
||||
|
||||
struct varint
|
||||
{
|
||||
static ostream& encode(ostream& str, const uint64_t& val);
|
||||
static istream& decode(istream& str, uint64_t& val);
|
||||
};
|
||||
|
||||
///
|
||||
/// Writer implementations
|
||||
///
|
||||
|
||||
// Primitive writers defined in .cpp file
|
||||
|
||||
// Array writer
|
||||
template<typename T, size_t N>
|
||||
ostream&
|
||||
operator<<(ostream& out, const std::array<T, N>& data)
|
||||
{
|
||||
for (const auto& item : data) {
|
||||
out << item;
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
// Optional writer
|
||||
template<typename T>
|
||||
ostream&
|
||||
operator<<(ostream& out, const std::optional<T>& opt)
|
||||
{
|
||||
if (!opt) {
|
||||
return out << uint8_t(0);
|
||||
}
|
||||
|
||||
return out << uint8_t(1) << opt::get(opt);
|
||||
}
|
||||
|
||||
// Enum writer
|
||||
template<typename T, std::enable_if_t<std::is_enum<T>::value, int> = 0>
|
||||
ostream&
|
||||
operator<<(ostream& str, const T& val)
|
||||
{
|
||||
auto u = static_cast<std::underlying_type_t<T>>(val);
|
||||
return str << u;
|
||||
}
|
||||
|
||||
// Vector writer
|
||||
template<typename T>
|
||||
ostream&
|
||||
operator<<(ostream& str, const std::vector<T>& vec)
|
||||
{
|
||||
// Pre-encode contents
|
||||
ostream temp;
|
||||
for (const auto& item : vec) {
|
||||
temp << item;
|
||||
}
|
||||
|
||||
// Write the encoded length, then the pre-encoded data
|
||||
varint::encode(str, temp._buffer.size());
|
||||
str.write_raw(temp.bytes());
|
||||
|
||||
return str;
|
||||
}
|
||||
|
||||
///
|
||||
/// Reader implementations
|
||||
///
|
||||
|
||||
// Primitive type readers defined in .cpp file
|
||||
|
||||
// Array reader
|
||||
template<typename T, size_t N>
|
||||
istream&
|
||||
operator>>(istream& in, std::array<T, N>& data)
|
||||
{
|
||||
for (auto& item : data) {
|
||||
in >> item;
|
||||
}
|
||||
return in;
|
||||
}
|
||||
|
||||
// Optional reader
|
||||
template<typename T>
|
||||
istream&
|
||||
operator>>(istream& in, std::optional<T>& opt)
|
||||
{
|
||||
uint8_t present = 0;
|
||||
in >> present;
|
||||
|
||||
switch (present) {
|
||||
case 0:
|
||||
opt.reset();
|
||||
return in;
|
||||
|
||||
case 1:
|
||||
opt.emplace();
|
||||
return in >> opt::get(opt);
|
||||
|
||||
default:
|
||||
throw std::invalid_argument("Malformed optional");
|
||||
}
|
||||
}
|
||||
|
||||
// Enum reader
|
||||
// XXX(rlb): It would be nice if this could enforce that the values are valid,
|
||||
// but C++ doesn't seem to have that ability. When used as a tag for variants,
|
||||
// the variant reader will enforce, at least.
|
||||
template<typename T, std::enable_if_t<std::is_enum<T>::value, int> = 0>
|
||||
istream&
|
||||
operator>>(istream& str, T& val)
|
||||
{
|
||||
std::underlying_type_t<T> u;
|
||||
str >> u;
|
||||
val = static_cast<T>(u);
|
||||
return str;
|
||||
}
|
||||
|
||||
// Vector reader
|
||||
template<typename T>
|
||||
istream&
|
||||
operator>>(istream& str, std::vector<T>& vec)
|
||||
{
|
||||
// Read the encoded data size
|
||||
auto size = uint64_t(0);
|
||||
varint::decode(str, size);
|
||||
if (size > str._buffer.size()) {
|
||||
throw ReadError("Vector is longer than remaining data");
|
||||
}
|
||||
|
||||
// Read the elements of the vector
|
||||
// NB: Remember that we store the vector in reverse order
|
||||
// NB: This requires that T be default-constructible
|
||||
istream r;
|
||||
r._buffer =
|
||||
std::vector<uint8_t>{ str._buffer.end() - size, str._buffer.end() };
|
||||
|
||||
vec.clear();
|
||||
while (r._buffer.size() > 0) {
|
||||
vec.emplace_back();
|
||||
r >> vec.back();
|
||||
}
|
||||
|
||||
// Truncate the primary buffer
|
||||
str._buffer.erase(str._buffer.end() - size, str._buffer.end());
|
||||
|
||||
return str;
|
||||
}
|
||||
|
||||
// Abbreviations
|
||||
template<typename T>
|
||||
std::vector<uint8_t>
|
||||
marshal(const T& value)
|
||||
{
|
||||
ostream w;
|
||||
w << value;
|
||||
return w.bytes();
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void
|
||||
unmarshal(const std::vector<uint8_t>& data, T& value)
|
||||
{
|
||||
istream r(data);
|
||||
r >> value;
|
||||
}
|
||||
|
||||
template<typename T, typename... Tp>
|
||||
T
|
||||
get(const std::vector<uint8_t>& data, Tp... args)
|
||||
{
|
||||
T value(args...);
|
||||
unmarshal(data, value);
|
||||
return value;
|
||||
}
|
||||
|
||||
// Use this macro to define struct serialization with minimal boilerplate
|
||||
#define TLS_SERIALIZABLE(...) \
|
||||
static const bool _tls_serializable = true; \
|
||||
auto _tls_fields_r() \
|
||||
{ \
|
||||
return std::forward_as_tuple(__VA_ARGS__); \
|
||||
} \
|
||||
auto _tls_fields_w() const \
|
||||
{ \
|
||||
return std::forward_as_tuple(__VA_ARGS__); \
|
||||
}
|
||||
|
||||
// If your struct contains nontrivial members (e.g., vectors), use this to
|
||||
// define traits for them.
|
||||
#define TLS_TRAITS(...) \
|
||||
static const bool _tls_has_traits = true; \
|
||||
using _tls_traits = std::tuple<__VA_ARGS__>;
|
||||
|
||||
template<typename T>
|
||||
struct is_serializable
|
||||
{
|
||||
template<typename U>
|
||||
static std::true_type test(decltype(U::_tls_serializable));
|
||||
|
||||
template<typename U>
|
||||
static std::false_type test(...);
|
||||
|
||||
static const bool value = decltype(test<T>(true))::value;
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
struct has_traits
|
||||
{
|
||||
template<typename U>
|
||||
static std::true_type test(decltype(U::_tls_has_traits));
|
||||
|
||||
template<typename U>
|
||||
static std::false_type test(...);
|
||||
|
||||
static const bool value = decltype(test<T>(true))::value;
|
||||
};
|
||||
|
||||
///
|
||||
/// Trait implementations
|
||||
///
|
||||
|
||||
// Pass-through (normal encoding/decoding)
|
||||
template<typename T>
|
||||
ostream&
|
||||
pass::encode(ostream& str, const T& val)
|
||||
{
|
||||
return str << val;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
istream&
|
||||
pass::decode(istream& str, T& val)
|
||||
{
|
||||
return str >> val;
|
||||
}
|
||||
|
||||
// Variant encoding
|
||||
template<typename Ts, typename Tv>
|
||||
constexpr Ts
|
||||
variant_map();
|
||||
|
||||
#define TLS_VARIANT_MAP(EnumType, MappedType, enum_value) \
|
||||
template<> \
|
||||
constexpr EnumType variant_map<EnumType, MappedType>() \
|
||||
{ \
|
||||
return EnumType::enum_value; \
|
||||
}
|
||||
|
||||
template<typename Ts>
|
||||
template<typename... Tp>
|
||||
inline Ts
|
||||
variant<Ts>::type(const var::variant<Tp...>& data)
|
||||
{
|
||||
const auto get_type = [](const auto& v) {
|
||||
return variant_map<Ts, std::decay_t<decltype(v)>>();
|
||||
};
|
||||
return var::visit(get_type, data);
|
||||
}
|
||||
|
||||
template<typename Ts>
|
||||
template<typename... Tp>
|
||||
ostream&
|
||||
variant<Ts>::encode(ostream& str, const var::variant<Tp...>& data)
|
||||
{
|
||||
const auto write_variant = [&str](auto&& v) {
|
||||
using Tv = std::decay_t<decltype(v)>;
|
||||
str << variant_map<Ts, Tv>() << v;
|
||||
};
|
||||
var::visit(write_variant, data);
|
||||
return str;
|
||||
}
|
||||
|
||||
template<typename Ts>
|
||||
template<size_t I, typename Te, typename... Tp>
|
||||
inline typename std::enable_if<I == sizeof...(Tp), void>::type
|
||||
variant<Ts>::read_variant(istream&, Te, var::variant<Tp...>&)
|
||||
{
|
||||
throw ReadError("Invalid variant type label");
|
||||
}
|
||||
|
||||
template<typename Ts>
|
||||
template<size_t I, typename Te, typename... Tp>
|
||||
inline
|
||||
typename std::enable_if < I<sizeof...(Tp), void>::type
|
||||
variant<Ts>::read_variant(istream& str,
|
||||
Te target_type,
|
||||
var::variant<Tp...>& v)
|
||||
{
|
||||
using Tc = var::variant_alternative_t<I, var::variant<Tp...>>;
|
||||
if (variant_map<Ts, Tc>() == target_type) {
|
||||
str >> v.template emplace<I>();
|
||||
return;
|
||||
}
|
||||
|
||||
read_variant<I + 1>(str, target_type, v);
|
||||
}
|
||||
|
||||
template<typename Ts>
|
||||
template<typename... Tp>
|
||||
istream&
|
||||
variant<Ts>::decode(istream& str, var::variant<Tp...>& data)
|
||||
{
|
||||
Ts target_type;
|
||||
str >> target_type;
|
||||
read_variant(str, target_type, data);
|
||||
return str;
|
||||
}
|
||||
|
||||
// Struct writer without traits (enabled by macro)
|
||||
template<size_t I = 0, typename... Tp>
|
||||
inline typename std::enable_if<I == sizeof...(Tp), void>::type
|
||||
write_tuple(ostream&, const std::tuple<Tp...>&)
|
||||
{
|
||||
}
|
||||
|
||||
template<size_t I = 0, typename... Tp>
|
||||
inline typename std::enable_if <
|
||||
I<sizeof...(Tp), void>::type
|
||||
write_tuple(ostream& str, const std::tuple<Tp...>& t)
|
||||
{
|
||||
str << std::get<I>(t);
|
||||
write_tuple<I + 1, Tp...>(str, t);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline
|
||||
typename std::enable_if<is_serializable<T>::value && !has_traits<T>::value,
|
||||
ostream&>::type
|
||||
operator<<(ostream& str, const T& obj)
|
||||
{
|
||||
write_tuple(str, obj._tls_fields_w());
|
||||
return str;
|
||||
}
|
||||
|
||||
// Struct writer with traits (enabled by macro)
|
||||
template<typename Tr, size_t I = 0, typename... Tp>
|
||||
inline typename std::enable_if<I == sizeof...(Tp), void>::type
|
||||
write_tuple_traits(ostream&, const std::tuple<Tp...>&)
|
||||
{
|
||||
}
|
||||
|
||||
template<typename Tr, size_t I = 0, typename... Tp>
|
||||
inline typename std::enable_if <
|
||||
I<sizeof...(Tp), void>::type
|
||||
write_tuple_traits(ostream& str, const std::tuple<Tp...>& t)
|
||||
{
|
||||
std::tuple_element_t<I, Tr>::encode(str, std::get<I>(t));
|
||||
write_tuple_traits<Tr, I + 1, Tp...>(str, t);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline
|
||||
typename std::enable_if<is_serializable<T>::value && has_traits<T>::value,
|
||||
ostream&>::type
|
||||
operator<<(ostream& str, const T& obj)
|
||||
{
|
||||
write_tuple_traits<typename T::_tls_traits>(str, obj._tls_fields_w());
|
||||
return str;
|
||||
}
|
||||
|
||||
// Struct reader without traits (enabled by macro)
|
||||
template<size_t I = 0, typename... Tp>
|
||||
inline typename std::enable_if<I == sizeof...(Tp), void>::type
|
||||
read_tuple(istream&, const std::tuple<Tp...>&)
|
||||
{
|
||||
}
|
||||
|
||||
template<size_t I = 0, typename... Tp>
|
||||
inline
|
||||
typename std::enable_if < I<sizeof...(Tp), void>::type
|
||||
read_tuple(istream& str, const std::tuple<Tp...>& t)
|
||||
{
|
||||
str >> std::get<I>(t);
|
||||
read_tuple<I + 1, Tp...>(str, t);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline
|
||||
typename std::enable_if<is_serializable<T>::value && !has_traits<T>::value,
|
||||
istream&>::type
|
||||
operator>>(istream& str, T& obj)
|
||||
{
|
||||
read_tuple(str, obj._tls_fields_r());
|
||||
return str;
|
||||
}
|
||||
|
||||
// Struct reader with traits (enabled by macro)
|
||||
template<typename Tr, size_t I = 0, typename... Tp>
|
||||
inline typename std::enable_if<I == sizeof...(Tp), void>::type
|
||||
read_tuple_traits(istream&, const std::tuple<Tp...>&)
|
||||
{
|
||||
}
|
||||
|
||||
template<typename Tr, size_t I = 0, typename... Tp>
|
||||
inline typename std::enable_if <
|
||||
I<sizeof...(Tp), void>::type
|
||||
read_tuple_traits(istream& str, const std::tuple<Tp...>& t)
|
||||
{
|
||||
std::tuple_element_t<I, Tr>::decode(str, std::get<I>(t));
|
||||
read_tuple_traits<Tr, I + 1, Tp...>(str, t);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline
|
||||
typename std::enable_if<is_serializable<T>::value && has_traits<T>::value,
|
||||
istream&>::type
|
||||
operator>>(istream& str, T& obj)
|
||||
{
|
||||
read_tuple_traits<typename T::_tls_traits>(str, obj._tls_fields_r());
|
||||
return str;
|
||||
}
|
||||
|
||||
} // namespace mlspp::tls
|
||||
178
DPP/mlspp/lib/tls_syntax/src/tls_syntax.cpp
Executable file
178
DPP/mlspp/lib/tls_syntax/src/tls_syntax.cpp
Executable file
@@ -0,0 +1,178 @@
|
||||
#include <tls/tls_syntax.h>
|
||||
|
||||
// NOLINTNEXTLINE(llvmlibc-implementation-in-namespace)
|
||||
namespace mlspp::tls {
|
||||
|
||||
void
|
||||
ostream::write_raw(const std::vector<uint8_t>& bytes)
|
||||
{
|
||||
// Not sure what the default argument is here
|
||||
_buffer.insert(_buffer.end(), bytes.begin(), bytes.end());
|
||||
}
|
||||
|
||||
// Primitive type writers
|
||||
ostream&
|
||||
ostream::write_uint(uint64_t value, int length)
|
||||
{
|
||||
for (int i = length - 1; i >= 0; --i) {
|
||||
_buffer.push_back(static_cast<uint8_t>(value >> unsigned(8 * i)));
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
ostream&
|
||||
operator<<(ostream& out, bool data)
|
||||
{
|
||||
if (data) {
|
||||
return out << uint8_t(1);
|
||||
}
|
||||
|
||||
return out << uint8_t(0);
|
||||
}
|
||||
|
||||
ostream&
|
||||
operator<<(ostream& out, uint8_t data) // NOLINT(llvmlibc-callee-namespace)
|
||||
{
|
||||
return out.write_uint(data, 1);
|
||||
}
|
||||
|
||||
ostream&
|
||||
operator<<(ostream& out, uint16_t data)
|
||||
{
|
||||
return out.write_uint(data, 2);
|
||||
}
|
||||
|
||||
ostream&
|
||||
operator<<(ostream& out, uint32_t data)
|
||||
{
|
||||
return out.write_uint(data, 4);
|
||||
}
|
||||
|
||||
ostream&
|
||||
operator<<(ostream& out, uint64_t data)
|
||||
{
|
||||
return out.write_uint(data, 8);
|
||||
}
|
||||
|
||||
// Because pop_back() on an empty vector is undefined
|
||||
uint8_t
|
||||
istream::next()
|
||||
{
|
||||
if (_buffer.empty()) {
|
||||
throw ReadError("Attempt to read from empty buffer");
|
||||
}
|
||||
|
||||
const uint8_t value = _buffer.back();
|
||||
_buffer.pop_back();
|
||||
return value;
|
||||
}
|
||||
|
||||
// Primitive type readers
|
||||
|
||||
istream&
|
||||
operator>>(istream& in, bool& data)
|
||||
{
|
||||
uint8_t val = 0;
|
||||
in >> val;
|
||||
|
||||
// Linter thinks uint8_t is signed (?)
|
||||
// NOLINTNEXTLINE(hicpp-signed-bitwise)
|
||||
if ((val & 0xFE) != 0) {
|
||||
throw ReadError("Malformed boolean");
|
||||
}
|
||||
|
||||
data = (val == 1);
|
||||
return in;
|
||||
}
|
||||
|
||||
istream&
|
||||
operator>>(istream& in, uint8_t& data) // NOLINT(llvmlibc-callee-namespace)
|
||||
{
|
||||
return in.read_uint(data, 1);
|
||||
}
|
||||
|
||||
istream&
|
||||
operator>>(istream& in, uint16_t& data)
|
||||
{
|
||||
return in.read_uint(data, 2);
|
||||
}
|
||||
|
||||
istream&
|
||||
operator>>(istream& in, uint32_t& data)
|
||||
{
|
||||
return in.read_uint(data, 4);
|
||||
}
|
||||
|
||||
istream&
|
||||
operator>>(istream& in, uint64_t& data)
|
||||
{
|
||||
return in.read_uint(data, 8);
|
||||
}
|
||||
|
||||
// Varint encoding
|
||||
static constexpr size_t VARINT_HEADER_OFFSET = 6;
|
||||
static constexpr uint64_t VARINT_1_HEADER = 0x00; // 0 << V1_OFFSET
|
||||
static constexpr uint64_t VARINT_2_HEADER = 0x4000; // 1 << V2_OFFSET
|
||||
static constexpr uint64_t VARINT_4_HEADER = 0x80000000; // 2 << V4_OFFSET
|
||||
static constexpr uint64_t VARINT_1_MAX = 0x3f;
|
||||
static constexpr uint64_t VARINT_2_MAX = 0x3fff;
|
||||
static constexpr uint64_t VARINT_4_MAX = 0x3fffffff;
|
||||
|
||||
ostream&
|
||||
varint::encode(ostream& str, const uint64_t& val)
|
||||
{
|
||||
if (val <= VARINT_1_MAX) {
|
||||
return str.write_uint(VARINT_1_HEADER | val, 1);
|
||||
}
|
||||
|
||||
if (val <= VARINT_2_MAX) {
|
||||
return str.write_uint(VARINT_2_HEADER | val, 2);
|
||||
}
|
||||
|
||||
if (val <= VARINT_4_MAX) {
|
||||
return str.write_uint(VARINT_4_HEADER | val, 4);
|
||||
}
|
||||
|
||||
throw WriteError("Varint value exceeds maximum size");
|
||||
}
|
||||
|
||||
istream&
|
||||
varint::decode(istream& str, uint64_t& val)
|
||||
{
|
||||
auto log_size = size_t(str._buffer.back() >> VARINT_HEADER_OFFSET);
|
||||
if (log_size > 2) {
|
||||
throw ReadError("Malformed varint header");
|
||||
}
|
||||
|
||||
auto read = uint64_t(0);
|
||||
auto read_bytes = size_t(size_t(1) << log_size);
|
||||
str.read_uint(read, read_bytes);
|
||||
|
||||
switch (log_size) {
|
||||
case 0:
|
||||
read ^= VARINT_1_HEADER;
|
||||
break;
|
||||
|
||||
case 1:
|
||||
read ^= VARINT_2_HEADER;
|
||||
if (read <= VARINT_1_MAX) {
|
||||
throw ReadError("Non-minimal varint");
|
||||
}
|
||||
break;
|
||||
|
||||
case 2:
|
||||
read ^= VARINT_4_HEADER;
|
||||
if (read <= VARINT_2_MAX) {
|
||||
throw ReadError("Non-minimal varint");
|
||||
}
|
||||
break;
|
||||
|
||||
default:
|
||||
throw ReadError("Malformed varint header");
|
||||
}
|
||||
|
||||
val = read;
|
||||
return str;
|
||||
}
|
||||
|
||||
} // namespace mlspp::tls
|
||||
13
DPP/mlspp/src/common.cpp
Executable file
13
DPP/mlspp/src/common.cpp
Executable file
@@ -0,0 +1,13 @@
|
||||
#include "mls/common.h"
|
||||
|
||||
namespace mlspp {
|
||||
|
||||
uint64_t
|
||||
seconds_since_epoch()
|
||||
{
|
||||
// TODO(RLB) This should use std::chrono, but that seems not to be available
|
||||
// on some platforms.
|
||||
return std::time(nullptr);
|
||||
}
|
||||
|
||||
} // namespace mlspp
|
||||
443
DPP/mlspp/src/core_types.cpp
Executable file
443
DPP/mlspp/src/core_types.cpp
Executable file
@@ -0,0 +1,443 @@
|
||||
#include "mls/core_types.h"
|
||||
#include "mls/messages.h"
|
||||
|
||||
#include "grease.h"
|
||||
|
||||
#include <set>
|
||||
|
||||
namespace mlspp {
|
||||
|
||||
///
|
||||
/// Extensions
|
||||
///
|
||||
|
||||
const Extension::Type RequiredCapabilitiesExtension::type =
|
||||
ExtensionType::required_capabilities;
|
||||
const Extension::Type ApplicationIDExtension::type =
|
||||
ExtensionType::application_id;
|
||||
|
||||
const std::array<uint16_t, 5> default_extensions = {
|
||||
ExtensionType::application_id, ExtensionType::ratchet_tree,
|
||||
ExtensionType::required_capabilities, ExtensionType::external_pub,
|
||||
ExtensionType::external_senders,
|
||||
};
|
||||
|
||||
const std::array<uint16_t, 8> default_proposals = {
|
||||
ProposalType::add,
|
||||
ProposalType::update,
|
||||
ProposalType::remove,
|
||||
ProposalType::psk,
|
||||
ProposalType::reinit,
|
||||
ProposalType::external_init,
|
||||
ProposalType::group_context_extensions,
|
||||
};
|
||||
|
||||
const std::array<ProtocolVersion, 1> all_supported_versions = {
|
||||
ProtocolVersion::mls10
|
||||
};
|
||||
|
||||
const std::array<CipherSuite::ID, 6> all_supported_ciphersuites = {
|
||||
CipherSuite::ID::X25519_AES128GCM_SHA256_Ed25519,
|
||||
CipherSuite::ID::P256_AES128GCM_SHA256_P256,
|
||||
CipherSuite::ID::X25519_CHACHA20POLY1305_SHA256_Ed25519,
|
||||
CipherSuite::ID::X448_AES256GCM_SHA512_Ed448,
|
||||
CipherSuite::ID::P521_AES256GCM_SHA512_P521,
|
||||
CipherSuite::ID::X448_CHACHA20POLY1305_SHA512_Ed448,
|
||||
};
|
||||
|
||||
const std::array<CredentialType, 4> all_supported_credentials = {
|
||||
CredentialType::basic,
|
||||
CredentialType::x509,
|
||||
CredentialType::userinfo_vc_draft_00,
|
||||
CredentialType::multi_draft_00
|
||||
};
|
||||
|
||||
Capabilities
|
||||
Capabilities::create_default()
|
||||
{
|
||||
return {
|
||||
{ all_supported_versions.begin(), all_supported_versions.end() },
|
||||
{ all_supported_ciphersuites.begin(), all_supported_ciphersuites.end() },
|
||||
{ /* No non-default extensions */ },
|
||||
{ /* No non-default proposals */ },
|
||||
{ all_supported_credentials.begin(), all_supported_credentials.end() },
|
||||
};
|
||||
}
|
||||
|
||||
bool
|
||||
Capabilities::extensions_supported(
|
||||
const std::vector<Extension::Type>& required) const
|
||||
{
|
||||
return stdx::all_of(required, [&](Extension::Type type) {
|
||||
if (stdx::contains(default_extensions, type)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return stdx::contains(extensions, type);
|
||||
});
|
||||
}
|
||||
|
||||
bool
|
||||
Capabilities::proposals_supported(
|
||||
const std::vector<Proposal::Type>& required) const
|
||||
{
|
||||
return stdx::all_of(required, [&](Proposal::Type type) {
|
||||
if (stdx::contains(default_proposals, type)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return stdx::contains(proposals, type);
|
||||
});
|
||||
}
|
||||
|
||||
bool
|
||||
Capabilities::credential_supported(const Credential& credential) const
|
||||
{
|
||||
return stdx::contains(credentials, credential.type());
|
||||
}
|
||||
|
||||
Lifetime
|
||||
Lifetime::create_default()
|
||||
{
|
||||
return Lifetime{ 0x0000000000000000, 0xffffffffffffffff };
|
||||
}
|
||||
|
||||
void
|
||||
ExtensionList::add(uint16_t type, bytes data)
|
||||
{
|
||||
auto curr = stdx::find_if(
|
||||
extensions, [&](const Extension& ext) -> bool { return ext.type == type; });
|
||||
if (curr != extensions.end()) {
|
||||
curr->data = std::move(data);
|
||||
return;
|
||||
}
|
||||
|
||||
extensions.push_back({ type, std::move(data) });
|
||||
}
|
||||
|
||||
bool
|
||||
ExtensionList::has(uint16_t type) const
|
||||
{
|
||||
return stdx::any_of(extensions,
|
||||
[&](const Extension& ext) { return ext.type == type; });
|
||||
}
|
||||
|
||||
///
|
||||
/// LeafNode
|
||||
///
|
||||
LeafNode::LeafNode(CipherSuite cipher_suite,
|
||||
HPKEPublicKey encryption_key_in,
|
||||
SignaturePublicKey signature_key_in,
|
||||
Credential credential_in,
|
||||
Capabilities capabilities_in,
|
||||
Lifetime lifetime_in,
|
||||
ExtensionList extensions_in,
|
||||
const SignaturePrivateKey& sig_priv)
|
||||
: encryption_key(std::move(encryption_key_in))
|
||||
, signature_key(std::move(signature_key_in))
|
||||
, credential(std::move(credential_in))
|
||||
, capabilities(std::move(capabilities_in))
|
||||
, content(lifetime_in)
|
||||
, extensions(std::move(extensions_in))
|
||||
{
|
||||
grease(extensions);
|
||||
grease(capabilities, extensions);
|
||||
sign(cipher_suite, sig_priv, std::nullopt);
|
||||
}
|
||||
|
||||
void
|
||||
LeafNode::set_capabilities(Capabilities capabilities_in)
|
||||
{
|
||||
capabilities = std::move(capabilities_in);
|
||||
grease(capabilities, extensions);
|
||||
}
|
||||
|
||||
LeafNode
|
||||
LeafNode::for_update(CipherSuite cipher_suite,
|
||||
const bytes& group_id,
|
||||
LeafIndex leaf_index,
|
||||
HPKEPublicKey encryption_key_in,
|
||||
const LeafNodeOptions& opts,
|
||||
const SignaturePrivateKey& sig_priv) const
|
||||
{
|
||||
auto clone = clone_with_options(std::move(encryption_key_in), opts);
|
||||
|
||||
clone.content = Empty{};
|
||||
clone.sign(cipher_suite, sig_priv, { { group_id, leaf_index } });
|
||||
|
||||
return clone;
|
||||
}
|
||||
|
||||
LeafNode
|
||||
LeafNode::for_commit(CipherSuite cipher_suite,
|
||||
const bytes& group_id,
|
||||
LeafIndex leaf_index,
|
||||
HPKEPublicKey encryption_key_in,
|
||||
const bytes& parent_hash,
|
||||
const LeafNodeOptions& opts,
|
||||
const SignaturePrivateKey& sig_priv) const
|
||||
{
|
||||
auto clone = clone_with_options(std::move(encryption_key_in), opts);
|
||||
|
||||
clone.content = ParentHash{ parent_hash };
|
||||
clone.sign(cipher_suite, sig_priv, { { group_id, leaf_index } });
|
||||
|
||||
return clone;
|
||||
}
|
||||
|
||||
LeafNodeSource
|
||||
LeafNode::source() const
|
||||
{
|
||||
return tls::variant<LeafNodeSource>::type(content);
|
||||
}
|
||||
|
||||
void
|
||||
LeafNode::sign(CipherSuite cipher_suite,
|
||||
const SignaturePrivateKey& sig_priv,
|
||||
const std::optional<MemberBinding>& binding)
|
||||
{
|
||||
const auto tbs = to_be_signed(binding);
|
||||
|
||||
if (sig_priv.public_key != signature_key) {
|
||||
throw InvalidParameterError("Signature key mismatch");
|
||||
}
|
||||
|
||||
if (!credential.valid_for(signature_key)) {
|
||||
throw InvalidParameterError("Credential not valid for signature key");
|
||||
}
|
||||
|
||||
signature = sig_priv.sign(cipher_suite, sign_label::leaf_node, tbs);
|
||||
}
|
||||
|
||||
bool
|
||||
LeafNode::verify(CipherSuite cipher_suite,
|
||||
const std::optional<MemberBinding>& binding) const
|
||||
{
|
||||
const auto tbs = to_be_signed(binding);
|
||||
|
||||
if (CredentialType::x509 == credential.type()) {
|
||||
const auto& cred = credential.get<X509Credential>();
|
||||
if (cred.signature_scheme() !=
|
||||
tls_signature_scheme(cipher_suite.sig().id)) {
|
||||
throw std::runtime_error("Signature algorithm invalid");
|
||||
}
|
||||
}
|
||||
|
||||
return signature_key.verify(
|
||||
cipher_suite, sign_label::leaf_node, tbs, signature);
|
||||
}
|
||||
|
||||
bool
|
||||
LeafNode::verify_expiry(uint64_t now) const
|
||||
{
|
||||
const auto valid = overloaded{
|
||||
[now](const Lifetime& lt) {
|
||||
return lt.not_before <= now && now <= lt.not_after;
|
||||
},
|
||||
[](const auto& /* other */) { return false; },
|
||||
};
|
||||
return var::visit(valid, content);
|
||||
}
|
||||
|
||||
bool
|
||||
LeafNode::verify_extension_support(const ExtensionList& ext_list) const
|
||||
{
|
||||
// Verify that extensions in the list are supported
|
||||
auto ext_types = stdx::transform<Extension::Type>(
|
||||
ext_list.extensions, [](const auto& ext) { return ext.type; });
|
||||
|
||||
if (!capabilities.extensions_supported(ext_types)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// If there's a RequiredCapabilities extension, verify support
|
||||
const auto maybe_req_capas = ext_list.find<RequiredCapabilitiesExtension>();
|
||||
if (!maybe_req_capas) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const auto& req_capas = opt::get(maybe_req_capas);
|
||||
return capabilities.extensions_supported(req_capas.extensions) &&
|
||||
capabilities.proposals_supported(req_capas.proposals);
|
||||
}
|
||||
|
||||
LeafNode
|
||||
LeafNode::clone_with_options(HPKEPublicKey encryption_key_in,
|
||||
const LeafNodeOptions& opts) const
|
||||
{
|
||||
auto clone = *this;
|
||||
|
||||
clone.encryption_key = std::move(encryption_key_in);
|
||||
|
||||
if (opts.credential) {
|
||||
clone.credential = opt::get(opts.credential);
|
||||
}
|
||||
|
||||
if (opts.capabilities) {
|
||||
clone.capabilities = opt::get(opts.capabilities);
|
||||
}
|
||||
|
||||
if (opts.extensions) {
|
||||
clone.extensions = opt::get(opts.extensions);
|
||||
}
|
||||
|
||||
return clone;
|
||||
}
|
||||
|
||||
// struct {
|
||||
// HPKEPublicKey encryption_key;
|
||||
// SignaturePublicKey signature_key;
|
||||
// Credential credential;
|
||||
// Capabilities capabilities;
|
||||
//
|
||||
// LeafNodeSource leaf_node_source;
|
||||
// select (leaf_node_source) {
|
||||
// case key_package:
|
||||
// Lifetime lifetime;
|
||||
//
|
||||
// case update:
|
||||
// struct{};
|
||||
//
|
||||
// case commit:
|
||||
// opaque parent_hash<V>;
|
||||
// }
|
||||
//
|
||||
// Extension extensions<V>;
|
||||
//
|
||||
// select (leaf_node_source) {
|
||||
// case key_package:
|
||||
// struct{};
|
||||
//
|
||||
// case update:
|
||||
// opaque group_id<V>;
|
||||
//
|
||||
// case commit:
|
||||
// opaque group_id<V>;
|
||||
// }
|
||||
// } LeafNodeTBS;
|
||||
struct LeafNodeTBS
|
||||
{
|
||||
const HPKEPublicKey& encryption_key;
|
||||
const SignaturePublicKey& signature_key;
|
||||
const Credential& credential;
|
||||
const Capabilities& capabilities;
|
||||
const var::variant<Lifetime, Empty, ParentHash>& content;
|
||||
const ExtensionList& extensions;
|
||||
|
||||
TLS_SERIALIZABLE(encryption_key,
|
||||
signature_key,
|
||||
credential,
|
||||
capabilities,
|
||||
content,
|
||||
extensions)
|
||||
TLS_TRAITS(tls::pass,
|
||||
tls::pass,
|
||||
tls::pass,
|
||||
tls::pass,
|
||||
tls::variant<LeafNodeSource>,
|
||||
tls::pass)
|
||||
};
|
||||
|
||||
bytes
|
||||
LeafNode::to_be_signed(const std::optional<MemberBinding>& binding) const
|
||||
{
|
||||
tls::ostream w;
|
||||
|
||||
w << LeafNodeTBS{
|
||||
encryption_key, signature_key, credential,
|
||||
capabilities, content, extensions,
|
||||
};
|
||||
|
||||
switch (source()) {
|
||||
case LeafNodeSource::key_package:
|
||||
break;
|
||||
|
||||
case LeafNodeSource::update:
|
||||
case LeafNodeSource::commit:
|
||||
w << opt::get(binding);
|
||||
}
|
||||
|
||||
return w.bytes();
|
||||
}
|
||||
|
||||
///
|
||||
/// NodeType, ParentNode, and KeyPackage
|
||||
///
|
||||
|
||||
bytes
|
||||
ParentNode::hash(CipherSuite suite) const
|
||||
{
|
||||
return suite.digest().hash(tls::marshal(this));
|
||||
}
|
||||
|
||||
KeyPackage::KeyPackage()
|
||||
: version(ProtocolVersion::mls10)
|
||||
, cipher_suite(CipherSuite::ID::unknown)
|
||||
{
|
||||
}
|
||||
|
||||
KeyPackage::KeyPackage(CipherSuite suite_in,
|
||||
HPKEPublicKey init_key_in,
|
||||
LeafNode leaf_node_in,
|
||||
ExtensionList extensions_in,
|
||||
const SignaturePrivateKey& sig_priv_in)
|
||||
: version(ProtocolVersion::mls10)
|
||||
, cipher_suite(suite_in)
|
||||
, init_key(std::move(init_key_in))
|
||||
, leaf_node(std::move(leaf_node_in))
|
||||
, extensions(std::move(extensions_in))
|
||||
{
|
||||
grease(extensions);
|
||||
sign(sig_priv_in);
|
||||
}
|
||||
|
||||
KeyPackageRef
|
||||
KeyPackage::ref() const
|
||||
{
|
||||
return cipher_suite.ref(*this);
|
||||
}
|
||||
|
||||
void
|
||||
KeyPackage::sign(const SignaturePrivateKey& sig_priv)
|
||||
{
|
||||
auto tbs = to_be_signed();
|
||||
signature = sig_priv.sign(cipher_suite, sign_label::key_package, tbs);
|
||||
}
|
||||
|
||||
bool
|
||||
KeyPackage::verify() const
|
||||
{
|
||||
// Verify the inner leaf node
|
||||
if (!leaf_node.verify(cipher_suite, std::nullopt)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check that the inner leaf node is intended for use in a KeyPackage
|
||||
if (leaf_node.source() != LeafNodeSource::key_package) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Verify the KeyPackage
|
||||
const auto tbs = to_be_signed();
|
||||
|
||||
if (CredentialType::x509 == leaf_node.credential.type()) {
|
||||
const auto& cred = leaf_node.credential.get<X509Credential>();
|
||||
if (cred.signature_scheme() !=
|
||||
tls_signature_scheme(cipher_suite.sig().id)) {
|
||||
throw std::runtime_error("Signature algorithm invalid");
|
||||
}
|
||||
}
|
||||
|
||||
return leaf_node.signature_key.verify(
|
||||
cipher_suite, sign_label::key_package, tbs, signature);
|
||||
}
|
||||
|
||||
bytes
|
||||
KeyPackage::to_be_signed() const
|
||||
{
|
||||
tls::ostream out;
|
||||
out << version << cipher_suite << init_key << leaf_node << extensions;
|
||||
return out.bytes();
|
||||
}
|
||||
|
||||
} // namespace mlspp
|
||||
298
DPP/mlspp/src/credential.cpp
Executable file
298
DPP/mlspp/src/credential.cpp
Executable file
@@ -0,0 +1,298 @@
|
||||
#include <hpke/certificate.h>
|
||||
#include <hpke/userinfo_vc.h>
|
||||
#include <mls/credential.h>
|
||||
#include <tls/tls_syntax.h>
|
||||
|
||||
namespace mlspp {
|
||||
|
||||
///
|
||||
/// X509Credential
|
||||
///
|
||||
|
||||
using mlspp::hpke::Certificate; // NOLINT(misc-unused-using-decls)
|
||||
using mlspp::hpke::Signature; // NOLINT(misc-unused-using-decls)
|
||||
using mlspp::hpke::UserInfoVC; // NOLINT(misc-unused-using-decls)
|
||||
|
||||
static const Signature&
|
||||
find_signature(Signature::ID id)
|
||||
{
|
||||
switch (id) {
|
||||
case Signature::ID::P256_SHA256:
|
||||
return Signature::get<Signature::ID::P256_SHA256>();
|
||||
case Signature::ID::P384_SHA384:
|
||||
return Signature::get<Signature::ID::P384_SHA384>();
|
||||
case Signature::ID::P521_SHA512:
|
||||
return Signature::get<Signature::ID::P521_SHA512>();
|
||||
case Signature::ID::Ed25519:
|
||||
return Signature::get<Signature::ID::Ed25519>();
|
||||
#if !defined(WITH_BORINGSSL)
|
||||
case Signature::ID::Ed448:
|
||||
return Signature::get<Signature::ID::Ed448>();
|
||||
#endif
|
||||
case Signature::ID::RSA_SHA256:
|
||||
return Signature::get<Signature::ID::RSA_SHA256>();
|
||||
default:
|
||||
throw InvalidParameterError("Unsupported algorithm");
|
||||
}
|
||||
}
|
||||
|
||||
static std::vector<X509Credential::CertData>
|
||||
bytes_to_x509_credential_data(const std::vector<bytes>& data_in)
|
||||
{
|
||||
return stdx::transform<X509Credential::CertData>(
|
||||
data_in, [](const bytes& der) { return X509Credential::CertData{ der }; });
|
||||
}
|
||||
|
||||
X509Credential::X509Credential(const std::vector<bytes>& der_chain_in)
|
||||
: der_chain(bytes_to_x509_credential_data(der_chain_in))
|
||||
{
|
||||
if (der_chain.empty()) {
|
||||
throw std::invalid_argument("empty certificate chain");
|
||||
}
|
||||
|
||||
// Parse the chain
|
||||
auto parsed = std::vector<Certificate>();
|
||||
for (const auto& cert : der_chain) {
|
||||
parsed.emplace_back(cert.data);
|
||||
}
|
||||
|
||||
// first element represents leaf cert
|
||||
const auto& sig = find_signature(parsed[0].public_key_algorithm());
|
||||
const auto pub_data = sig.serialize(*parsed[0].public_key);
|
||||
_signature_scheme = tls_signature_scheme(parsed[0].public_key_algorithm());
|
||||
_public_key = SignaturePublicKey{ pub_data };
|
||||
|
||||
// verify chain for valid signatures
|
||||
for (size_t i = 0; i < der_chain.size() - 1; i++) {
|
||||
if (!parsed[i].valid_from(parsed[i + 1])) {
|
||||
throw std::runtime_error("Certificate Chain validation failure");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SignatureScheme
|
||||
X509Credential::signature_scheme() const
|
||||
{
|
||||
return _signature_scheme;
|
||||
}
|
||||
|
||||
SignaturePublicKey
|
||||
X509Credential::public_key() const
|
||||
{
|
||||
return _public_key;
|
||||
}
|
||||
|
||||
bool
|
||||
X509Credential::valid_for(const SignaturePublicKey& pub) const
|
||||
{
|
||||
return pub == public_key();
|
||||
}
|
||||
|
||||
tls::ostream&
|
||||
operator<<(tls::ostream& str, const X509Credential& obj)
|
||||
{
|
||||
return str << obj.der_chain;
|
||||
}
|
||||
|
||||
tls::istream&
|
||||
operator>>(tls::istream& str, X509Credential& obj)
|
||||
{
|
||||
auto der_chain = std::vector<X509Credential::CertData>{};
|
||||
str >> der_chain;
|
||||
|
||||
auto der_in = stdx::transform<bytes>(
|
||||
der_chain, [](const auto& cert_data) { return cert_data.data; });
|
||||
obj = X509Credential(der_in);
|
||||
|
||||
return str;
|
||||
}
|
||||
|
||||
bool
|
||||
operator==(const X509Credential& lhs, const X509Credential& rhs)
|
||||
{
|
||||
return lhs.der_chain == rhs.der_chain;
|
||||
}
|
||||
|
||||
///
|
||||
/// UserInfoVCCredential
|
||||
///
|
||||
UserInfoVCCredential::UserInfoVCCredential(std::string userinfo_vc_jwt_in)
|
||||
: userinfo_vc_jwt(std::move(userinfo_vc_jwt_in))
|
||||
, _vc(std::make_shared<hpke::UserInfoVC>(userinfo_vc_jwt))
|
||||
{
|
||||
}
|
||||
|
||||
bool
|
||||
// NOLINTNEXTLINE(readability-convert-member-functions-to-static)
|
||||
UserInfoVCCredential::valid_for(const SignaturePublicKey& pub) const
|
||||
{
|
||||
const auto& vc_pub = _vc->public_key();
|
||||
return pub.data == vc_pub.sig.serialize(*vc_pub.key);
|
||||
}
|
||||
|
||||
bool
|
||||
UserInfoVCCredential::valid_from(const PublicJWK& pub) const
|
||||
{
|
||||
const auto& sig = _vc->signature_algorithm();
|
||||
if (pub.signature_scheme != tls_signature_scheme(sig.id)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto sig_pub = sig.deserialize(pub.public_key.data);
|
||||
return _vc->valid_from(*sig_pub);
|
||||
}
|
||||
|
||||
tls::ostream
|
||||
operator<<(tls::ostream& str, const UserInfoVCCredential& obj)
|
||||
{
|
||||
return str << from_ascii(obj.userinfo_vc_jwt);
|
||||
}
|
||||
|
||||
tls::istream
|
||||
operator>>(tls::istream& str, UserInfoVCCredential& obj)
|
||||
{
|
||||
auto jwt = bytes{};
|
||||
str >> jwt;
|
||||
obj = UserInfoVCCredential(to_ascii(jwt));
|
||||
return str;
|
||||
}
|
||||
|
||||
bool
|
||||
operator==(const UserInfoVCCredential& lhs, const UserInfoVCCredential& rhs)
|
||||
{
|
||||
return lhs.userinfo_vc_jwt == rhs.userinfo_vc_jwt;
|
||||
}
|
||||
|
||||
bool
|
||||
operator!=(const UserInfoVCCredential& lhs, const UserInfoVCCredential& rhs)
|
||||
{
|
||||
return !(lhs == rhs);
|
||||
}
|
||||
|
||||
///
|
||||
/// CredentialBinding and MultiCredential
|
||||
///
|
||||
|
||||
struct CredentialBindingTBS
|
||||
{
|
||||
const CipherSuite& cipher_suite;
|
||||
const Credential& credential;
|
||||
const SignaturePublicKey& credential_key;
|
||||
const SignaturePublicKey& signature_key;
|
||||
|
||||
TLS_SERIALIZABLE(cipher_suite, credential, credential_key, signature_key)
|
||||
};
|
||||
|
||||
CredentialBinding::CredentialBinding(CipherSuite cipher_suite_in,
|
||||
Credential credential_in,
|
||||
const SignaturePrivateKey& credential_priv,
|
||||
const SignaturePublicKey& signature_key)
|
||||
: cipher_suite(cipher_suite_in)
|
||||
, credential(std::move(credential_in))
|
||||
, credential_key(credential_priv.public_key)
|
||||
{
|
||||
if (credential.type() == CredentialType::multi_draft_00) {
|
||||
throw InvalidParameterError("Multi-credentials cannot be nested");
|
||||
}
|
||||
|
||||
if (!credential.valid_for(credential_key)) {
|
||||
throw InvalidParameterError("Credential key does not match credential");
|
||||
}
|
||||
|
||||
signature = credential_priv.sign(
|
||||
cipher_suite, sign_label::multi_credential, to_be_signed(signature_key));
|
||||
}
|
||||
|
||||
bytes
|
||||
CredentialBinding::to_be_signed(const SignaturePublicKey& signature_key) const
|
||||
{
|
||||
return tls::marshal(CredentialBindingTBS{
|
||||
cipher_suite, credential, credential_key, signature_key });
|
||||
}
|
||||
|
||||
bool
|
||||
CredentialBinding::valid_for(const SignaturePublicKey& signature_key) const
|
||||
{
|
||||
auto valid_self = credential.valid_for(credential_key);
|
||||
auto valid_other = credential_key.verify(cipher_suite,
|
||||
sign_label::multi_credential,
|
||||
to_be_signed(signature_key),
|
||||
signature);
|
||||
|
||||
return valid_self && valid_other;
|
||||
}
|
||||
|
||||
MultiCredential::MultiCredential(
|
||||
const std::vector<CredentialBindingInput>& binding_inputs,
|
||||
const SignaturePublicKey& signature_key)
|
||||
{
|
||||
bindings =
|
||||
stdx::transform<CredentialBinding>(binding_inputs, [&](auto&& input) {
|
||||
return CredentialBinding(input.cipher_suite,
|
||||
input.credential,
|
||||
input.credential_priv,
|
||||
signature_key);
|
||||
});
|
||||
}
|
||||
|
||||
bool
|
||||
MultiCredential::valid_for(const SignaturePublicKey& pub) const
|
||||
{
|
||||
return stdx::all_of(
|
||||
bindings, [&](const auto& binding) { return binding.valid_for(pub); });
|
||||
}
|
||||
|
||||
///
|
||||
/// Credential
|
||||
///
|
||||
|
||||
CredentialType
|
||||
Credential::type() const
|
||||
{
|
||||
return tls::variant<CredentialType>::type(_cred);
|
||||
}
|
||||
|
||||
Credential
|
||||
Credential::basic(const bytes& identity)
|
||||
{
|
||||
return { BasicCredential{ identity } };
|
||||
}
|
||||
|
||||
Credential
|
||||
Credential::x509(const std::vector<bytes>& der_chain)
|
||||
{
|
||||
return { X509Credential{ der_chain } };
|
||||
}
|
||||
|
||||
Credential
|
||||
Credential::multi(const std::vector<CredentialBindingInput>& binding_inputs,
|
||||
const SignaturePublicKey& signature_key)
|
||||
{
|
||||
return { MultiCredential{ binding_inputs, signature_key } };
|
||||
}
|
||||
|
||||
Credential
|
||||
Credential::userinfo_vc(const std::string& userinfo_vc_jwt)
|
||||
{
|
||||
return { UserInfoVCCredential{ userinfo_vc_jwt } };
|
||||
}
|
||||
|
||||
bool
|
||||
Credential::valid_for(const SignaturePublicKey& pub) const
|
||||
{
|
||||
const auto pub_key_match = overloaded{
|
||||
[&](const X509Credential& x509) { return x509.valid_for(pub); },
|
||||
[](const BasicCredential& /* basic */) { return true; },
|
||||
[&](const UserInfoVCCredential& vc) { return vc.valid_for(pub); },
|
||||
[&](const MultiCredential& multi) { return multi.valid_for(pub); },
|
||||
};
|
||||
|
||||
return var::visit(pub_key_match, _cred);
|
||||
}
|
||||
|
||||
Credential::Credential(SpecificCredential specific)
|
||||
: _cred(std::move(specific))
|
||||
{
|
||||
}
|
||||
|
||||
} // namespace mlspp
|
||||
498
DPP/mlspp/src/crypto.cpp
Executable file
498
DPP/mlspp/src/crypto.cpp
Executable file
@@ -0,0 +1,498 @@
|
||||
#include <mls/core_types.h>
|
||||
#include <mls/crypto.h>
|
||||
#include <mls/messages.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
using mlspp::hpke::AEAD; // NOLINT(misc-unused-using-decls)
|
||||
using mlspp::hpke::Digest; // NOLINT(misc-unused-using-decls)
|
||||
using mlspp::hpke::HPKE; // NOLINT(misc-unused-using-decls)
|
||||
using mlspp::hpke::KDF; // NOLINT(misc-unused-using-decls)
|
||||
using mlspp::hpke::KEM; // NOLINT(misc-unused-using-decls)
|
||||
using mlspp::hpke::Signature; // NOLINT(misc-unused-using-decls)
|
||||
|
||||
namespace mlspp {
|
||||
|
||||
SignatureScheme
|
||||
tls_signature_scheme(Signature::ID id)
|
||||
{
|
||||
switch (id) {
|
||||
case Signature::ID::P256_SHA256:
|
||||
return SignatureScheme::ecdsa_secp256r1_sha256;
|
||||
case Signature::ID::P384_SHA384:
|
||||
return SignatureScheme::ecdsa_secp384r1_sha384;
|
||||
case Signature::ID::P521_SHA512:
|
||||
return SignatureScheme::ecdsa_secp521r1_sha512;
|
||||
case Signature::ID::Ed25519:
|
||||
return SignatureScheme::ed25519;
|
||||
#if !defined(WITH_BORINGSSL)
|
||||
case Signature::ID::Ed448:
|
||||
return SignatureScheme::ed448;
|
||||
#endif
|
||||
case Signature::ID::RSA_SHA256:
|
||||
return SignatureScheme::rsa_pkcs1_sha256;
|
||||
default:
|
||||
throw InvalidParameterError("Unsupported algorithm");
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
/// CipherSuites and details
|
||||
///
|
||||
|
||||
CipherSuite::CipherSuite()
|
||||
: id(ID::unknown)
|
||||
{
|
||||
}
|
||||
|
||||
CipherSuite::CipherSuite(ID id_in)
|
||||
: id(id_in)
|
||||
{
|
||||
}
|
||||
|
||||
SignatureScheme
|
||||
CipherSuite::signature_scheme() const
|
||||
{
|
||||
switch (id) {
|
||||
case ID::X25519_AES128GCM_SHA256_Ed25519:
|
||||
case ID::X25519_CHACHA20POLY1305_SHA256_Ed25519:
|
||||
return SignatureScheme::ed25519;
|
||||
case ID::P256_AES128GCM_SHA256_P256:
|
||||
return SignatureScheme::ecdsa_secp256r1_sha256;
|
||||
case ID::X448_AES256GCM_SHA512_Ed448:
|
||||
case ID::X448_CHACHA20POLY1305_SHA512_Ed448:
|
||||
return SignatureScheme::ed448;
|
||||
case ID::P521_AES256GCM_SHA512_P521:
|
||||
return SignatureScheme::ecdsa_secp521r1_sha512;
|
||||
case ID::P384_AES256GCM_SHA384_P384:
|
||||
return SignatureScheme::ecdsa_secp384r1_sha384;
|
||||
default:
|
||||
throw InvalidParameterError("Unsupported algorithm");
|
||||
}
|
||||
}
|
||||
|
||||
const CipherSuite::Ciphers&
|
||||
CipherSuite::get() const
|
||||
{
|
||||
static const auto ciphers_X25519_AES128GCM_SHA256_Ed25519 =
|
||||
CipherSuite::Ciphers{
|
||||
HPKE(KEM::ID::DHKEM_X25519_SHA256,
|
||||
KDF::ID::HKDF_SHA256,
|
||||
AEAD::ID::AES_128_GCM),
|
||||
Digest::get<Digest::ID::SHA256>(),
|
||||
Signature::get<Signature::ID::Ed25519>(),
|
||||
};
|
||||
|
||||
static const auto ciphers_P256_AES128GCM_SHA256_P256 = CipherSuite::Ciphers{
|
||||
HPKE(
|
||||
KEM::ID::DHKEM_P256_SHA256, KDF::ID::HKDF_SHA256, AEAD::ID::AES_128_GCM),
|
||||
Digest::get<Digest::ID::SHA256>(),
|
||||
Signature::get<Signature::ID::P256_SHA256>(),
|
||||
};
|
||||
|
||||
static const auto ciphers_X25519_CHACHA20POLY1305_SHA256_Ed25519 =
|
||||
CipherSuite::Ciphers{
|
||||
HPKE(KEM::ID::DHKEM_X25519_SHA256,
|
||||
KDF::ID::HKDF_SHA256,
|
||||
AEAD::ID::CHACHA20_POLY1305),
|
||||
Digest::get<Digest::ID::SHA256>(),
|
||||
Signature::get<Signature::ID::Ed25519>(),
|
||||
};
|
||||
|
||||
static const auto ciphers_P521_AES256GCM_SHA512_P521 = CipherSuite::Ciphers{
|
||||
HPKE(
|
||||
KEM::ID::DHKEM_P521_SHA512, KDF::ID::HKDF_SHA512, AEAD::ID::AES_256_GCM),
|
||||
Digest::get<Digest::ID::SHA512>(),
|
||||
Signature::get<Signature::ID::P521_SHA512>(),
|
||||
};
|
||||
|
||||
static const auto ciphers_P384_AES256GCM_SHA384_P384 = CipherSuite::Ciphers{
|
||||
HPKE(
|
||||
KEM::ID::DHKEM_P384_SHA384, KDF::ID::HKDF_SHA384, AEAD::ID::AES_256_GCM),
|
||||
Digest::get<Digest::ID::SHA384>(),
|
||||
Signature::get<Signature::ID::P384_SHA384>(),
|
||||
};
|
||||
|
||||
#if !defined(WITH_BORINGSSL)
|
||||
static const auto ciphers_X448_AES256GCM_SHA512_Ed448 = CipherSuite::Ciphers{
|
||||
HPKE(
|
||||
KEM::ID::DHKEM_X448_SHA512, KDF::ID::HKDF_SHA512, AEAD::ID::AES_256_GCM),
|
||||
Digest::get<Digest::ID::SHA512>(),
|
||||
Signature::get<Signature::ID::Ed448>(),
|
||||
};
|
||||
|
||||
static const auto ciphers_X448_CHACHA20POLY1305_SHA512_Ed448 =
|
||||
CipherSuite::Ciphers{
|
||||
HPKE(KEM::ID::DHKEM_X448_SHA512,
|
||||
KDF::ID::HKDF_SHA512,
|
||||
AEAD::ID::CHACHA20_POLY1305),
|
||||
Digest::get<Digest::ID::SHA512>(),
|
||||
Signature::get<Signature::ID::Ed448>(),
|
||||
};
|
||||
#endif
|
||||
|
||||
switch (id) {
|
||||
case ID::unknown:
|
||||
throw InvalidParameterError("Uninitialized ciphersuite");
|
||||
|
||||
case ID::X25519_AES128GCM_SHA256_Ed25519:
|
||||
return ciphers_X25519_AES128GCM_SHA256_Ed25519;
|
||||
|
||||
case ID::P256_AES128GCM_SHA256_P256:
|
||||
return ciphers_P256_AES128GCM_SHA256_P256;
|
||||
|
||||
case ID::X25519_CHACHA20POLY1305_SHA256_Ed25519:
|
||||
return ciphers_X25519_CHACHA20POLY1305_SHA256_Ed25519;
|
||||
|
||||
case ID::P521_AES256GCM_SHA512_P521:
|
||||
return ciphers_P521_AES256GCM_SHA512_P521;
|
||||
|
||||
case ID::P384_AES256GCM_SHA384_P384:
|
||||
return ciphers_P384_AES256GCM_SHA384_P384;
|
||||
|
||||
#if !defined(WITH_BORINGSSL)
|
||||
case ID::X448_AES256GCM_SHA512_Ed448:
|
||||
return ciphers_X448_AES256GCM_SHA512_Ed448;
|
||||
|
||||
case ID::X448_CHACHA20POLY1305_SHA512_Ed448:
|
||||
return ciphers_X448_CHACHA20POLY1305_SHA512_Ed448;
|
||||
#endif
|
||||
|
||||
default:
|
||||
throw InvalidParameterError("Unsupported ciphersuite");
|
||||
}
|
||||
}
|
||||
|
||||
struct HKDFLabel
|
||||
{
|
||||
uint16_t length;
|
||||
bytes label;
|
||||
bytes context;
|
||||
|
||||
TLS_SERIALIZABLE(length, label, context)
|
||||
};
|
||||
|
||||
bytes
|
||||
CipherSuite::expand_with_label(const bytes& secret,
|
||||
const std::string& label,
|
||||
const bytes& context,
|
||||
size_t length) const
|
||||
{
|
||||
auto mls_label = from_ascii(std::string("MLS 1.0 ") + label);
|
||||
auto length16 = static_cast<uint16_t>(length);
|
||||
auto label_bytes = tls::marshal(HKDFLabel{ length16, mls_label, context });
|
||||
return get().hpke.kdf.expand(secret, label_bytes, length);
|
||||
}
|
||||
|
||||
bytes
|
||||
CipherSuite::derive_secret(const bytes& secret, const std::string& label) const
|
||||
{
|
||||
return expand_with_label(secret, label, {}, secret_size());
|
||||
}
|
||||
|
||||
bytes
|
||||
CipherSuite::derive_tree_secret(const bytes& secret,
|
||||
const std::string& label,
|
||||
uint32_t generation,
|
||||
size_t length) const
|
||||
{
|
||||
return expand_with_label(secret, label, tls::marshal(generation), length);
|
||||
}
|
||||
|
||||
#if WITH_BORINGSSL
|
||||
const std::array<CipherSuite::ID, 5> all_supported_suites = {
|
||||
CipherSuite::ID::X25519_AES128GCM_SHA256_Ed25519,
|
||||
CipherSuite::ID::P256_AES128GCM_SHA256_P256,
|
||||
CipherSuite::ID::X25519_CHACHA20POLY1305_SHA256_Ed25519,
|
||||
CipherSuite::ID::P521_AES256GCM_SHA512_P521,
|
||||
CipherSuite::ID::P384_AES256GCM_SHA384_P384,
|
||||
};
|
||||
#else
|
||||
const std::array<CipherSuite::ID, 7> all_supported_suites = {
|
||||
CipherSuite::ID::X25519_AES128GCM_SHA256_Ed25519,
|
||||
CipherSuite::ID::P256_AES128GCM_SHA256_P256,
|
||||
CipherSuite::ID::X25519_CHACHA20POLY1305_SHA256_Ed25519,
|
||||
CipherSuite::ID::P521_AES256GCM_SHA512_P521,
|
||||
CipherSuite::ID::P384_AES256GCM_SHA384_P384,
|
||||
CipherSuite::ID::X448_CHACHA20POLY1305_SHA512_Ed448,
|
||||
CipherSuite::ID::X448_AES256GCM_SHA512_Ed448,
|
||||
};
|
||||
#endif
|
||||
|
||||
// MakeKeyPackageRef(value) = KDF.expand(
|
||||
// KDF.extract("", value), "MLS 1.0 KeyPackage Reference", 16)
|
||||
template<>
|
||||
const bytes&
|
||||
CipherSuite::reference_label<KeyPackage>()
|
||||
{
|
||||
static const auto label = from_ascii("MLS 1.0 KeyPackage Reference");
|
||||
return label;
|
||||
}
|
||||
|
||||
// MakeProposalRef(value) = KDF.expand(
|
||||
// KDF.extract("", value), "MLS 1.0 Proposal Reference", 16)
|
||||
//
|
||||
// Even though the label says "Proposal", we actually hash the entire enclosing
|
||||
// AuthenticatedContent object.
|
||||
template<>
|
||||
const bytes&
|
||||
CipherSuite::reference_label<AuthenticatedContent>()
|
||||
{
|
||||
static const auto label = from_ascii("MLS 1.0 Proposal Reference");
|
||||
return label;
|
||||
}
|
||||
|
||||
///
|
||||
/// HPKEPublicKey and HPKEPrivateKey
|
||||
///
|
||||
|
||||
// This function produces a non-literal type, so it can't be constexpr.
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
|
||||
#define MLS_1_0_PLUS(label) from_ascii("MLS 1.0 " label)
|
||||
|
||||
static bytes
|
||||
mls_1_0_plus(const std::string& label)
|
||||
{
|
||||
auto plus = "MLS 1.0 "s + label;
|
||||
return from_ascii(plus);
|
||||
}
|
||||
|
||||
namespace encrypt_label {
|
||||
const std::string update_path_node = "UpdatePathNode";
|
||||
const std::string welcome = "Welcome";
|
||||
} // namespace encrypt_label
|
||||
|
||||
struct EncryptContext
|
||||
{
|
||||
const bytes& label;
|
||||
const bytes& content;
|
||||
TLS_SERIALIZABLE(label, content)
|
||||
};
|
||||
|
||||
HPKECiphertext
|
||||
HPKEPublicKey::encrypt(CipherSuite suite,
|
||||
const std::string& label,
|
||||
const bytes& context,
|
||||
const bytes& pt) const
|
||||
{
|
||||
auto label_plus = mls_1_0_plus(label);
|
||||
auto encrypt_context = tls::marshal(EncryptContext{ label_plus, context });
|
||||
auto pkR = suite.hpke().kem.deserialize(data);
|
||||
auto [enc, ctx] = suite.hpke().setup_base_s(*pkR, encrypt_context);
|
||||
auto ct = ctx.seal({}, pt);
|
||||
return HPKECiphertext{ enc, ct };
|
||||
}
|
||||
|
||||
std::tuple<bytes, bytes>
|
||||
HPKEPublicKey::do_export(CipherSuite suite,
|
||||
const bytes& info,
|
||||
const std::string& label,
|
||||
size_t size) const
|
||||
{
|
||||
auto label_data = from_ascii(label);
|
||||
auto pkR = suite.hpke().kem.deserialize(data);
|
||||
auto [enc, ctx] = suite.hpke().setup_base_s(*pkR, info);
|
||||
auto exported = ctx.do_export(label_data, size);
|
||||
return std::make_tuple(enc, exported);
|
||||
}
|
||||
|
||||
HPKEPrivateKey
|
||||
HPKEPrivateKey::generate(CipherSuite suite)
|
||||
{
|
||||
auto priv = suite.hpke().kem.generate_key_pair();
|
||||
auto priv_data = suite.hpke().kem.serialize_private(*priv);
|
||||
auto pub = priv->public_key();
|
||||
auto pub_data = suite.hpke().kem.serialize(*pub);
|
||||
return { priv_data, pub_data };
|
||||
}
|
||||
|
||||
HPKEPrivateKey
|
||||
HPKEPrivateKey::parse(CipherSuite suite, const bytes& data)
|
||||
{
|
||||
auto priv = suite.hpke().kem.deserialize_private(data);
|
||||
auto pub = priv->public_key();
|
||||
auto pub_data = suite.hpke().kem.serialize(*pub);
|
||||
return { data, pub_data };
|
||||
}
|
||||
|
||||
HPKEPrivateKey
|
||||
HPKEPrivateKey::derive(CipherSuite suite, const bytes& secret)
|
||||
{
|
||||
auto priv = suite.hpke().kem.derive_key_pair(secret);
|
||||
auto priv_data = suite.hpke().kem.serialize_private(*priv);
|
||||
auto pub = priv->public_key();
|
||||
auto pub_data = suite.hpke().kem.serialize(*pub);
|
||||
return { priv_data, pub_data };
|
||||
}
|
||||
|
||||
bytes
|
||||
HPKEPrivateKey::decrypt(CipherSuite suite,
|
||||
const std::string& label,
|
||||
const bytes& context,
|
||||
const HPKECiphertext& ct) const
|
||||
{
|
||||
auto label_plus = mls_1_0_plus(label);
|
||||
auto encrypt_context = tls::marshal(EncryptContext{ label_plus, context });
|
||||
auto skR = suite.hpke().kem.deserialize_private(data);
|
||||
auto ctx = suite.hpke().setup_base_r(ct.kem_output, *skR, encrypt_context);
|
||||
auto pt = ctx.open({}, ct.ciphertext);
|
||||
if (!pt) {
|
||||
throw InvalidParameterError("HPKE decryption failure");
|
||||
}
|
||||
|
||||
return opt::get(pt);
|
||||
}
|
||||
|
||||
bytes
|
||||
HPKEPrivateKey::do_export(CipherSuite suite,
|
||||
const bytes& info,
|
||||
const bytes& kem_output,
|
||||
const std::string& label,
|
||||
size_t size) const
|
||||
{
|
||||
auto label_data = from_ascii(label);
|
||||
auto skR = suite.hpke().kem.deserialize_private(data);
|
||||
auto ctx = suite.hpke().setup_base_r(kem_output, *skR, info);
|
||||
return ctx.do_export(label_data, size);
|
||||
}
|
||||
|
||||
HPKEPrivateKey::HPKEPrivateKey(bytes priv_data, bytes pub_data)
|
||||
: data(std::move(priv_data))
|
||||
, public_key{ std::move(pub_data) }
|
||||
{
|
||||
}
|
||||
|
||||
void
|
||||
HPKEPrivateKey::set_public_key(CipherSuite suite)
|
||||
{
|
||||
const auto priv = suite.hpke().kem.deserialize_private(data);
|
||||
auto pub = priv->public_key();
|
||||
public_key.data = suite.hpke().kem.serialize(*pub);
|
||||
}
|
||||
|
||||
///
|
||||
/// SignaturePublicKey and SignaturePrivateKey
|
||||
///
|
||||
namespace sign_label {
|
||||
const std::string mls_content = "FramedContentTBS";
|
||||
const std::string leaf_node = "LeafNodeTBS";
|
||||
const std::string key_package = "KeyPackageTBS";
|
||||
const std::string group_info = "GroupInfoTBS";
|
||||
const std::string multi_credential = "MultiCredential";
|
||||
} // namespace sign_label
|
||||
|
||||
struct SignContent
|
||||
{
|
||||
const bytes& label;
|
||||
const bytes& content;
|
||||
TLS_SERIALIZABLE(label, content)
|
||||
};
|
||||
|
||||
bool
|
||||
SignaturePublicKey::verify(const CipherSuite& suite,
|
||||
const std::string& label,
|
||||
const bytes& message,
|
||||
const bytes& signature) const
|
||||
{
|
||||
auto label_plus = mls_1_0_plus(label);
|
||||
const auto content = tls::marshal(SignContent{ label_plus, message });
|
||||
auto pub = suite.sig().deserialize(data);
|
||||
return suite.sig().verify(content, signature, *pub);
|
||||
}
|
||||
|
||||
SignaturePublicKey
|
||||
SignaturePublicKey::from_jwk(CipherSuite suite, const std::string& json_str)
|
||||
{
|
||||
auto pub = suite.sig().import_jwk(json_str);
|
||||
auto pub_data = suite.sig().serialize(*pub);
|
||||
return SignaturePublicKey{ pub_data };
|
||||
}
|
||||
|
||||
std::string
|
||||
SignaturePublicKey::to_jwk(CipherSuite suite) const
|
||||
{
|
||||
auto pub = suite.sig().deserialize(data);
|
||||
return suite.sig().export_jwk(*pub);
|
||||
}
|
||||
|
||||
PublicJWK
|
||||
PublicJWK::parse(const std::string& jwk_json)
|
||||
{
|
||||
const auto parsed = Signature::parse_jwk(jwk_json);
|
||||
const auto scheme = tls_signature_scheme(parsed.sig.id);
|
||||
const auto pub_data = parsed.sig.serialize(*parsed.key);
|
||||
return { scheme, parsed.key_id, { pub_data } };
|
||||
}
|
||||
|
||||
SignaturePrivateKey
|
||||
SignaturePrivateKey::generate(CipherSuite suite)
|
||||
{
|
||||
auto priv = suite.sig().generate_key_pair();
|
||||
auto priv_data = suite.sig().serialize_private(*priv);
|
||||
auto pub = priv->public_key();
|
||||
auto pub_data = suite.sig().serialize(*pub);
|
||||
return { priv_data, pub_data };
|
||||
}
|
||||
|
||||
SignaturePrivateKey
|
||||
SignaturePrivateKey::parse(CipherSuite suite, const bytes& data)
|
||||
{
|
||||
auto priv = suite.sig().deserialize_private(data);
|
||||
auto pub = priv->public_key();
|
||||
auto pub_data = suite.sig().serialize(*pub);
|
||||
return { data, pub_data };
|
||||
}
|
||||
|
||||
SignaturePrivateKey
|
||||
SignaturePrivateKey::derive(CipherSuite suite, const bytes& secret)
|
||||
{
|
||||
auto priv = suite.sig().derive_key_pair(secret);
|
||||
auto priv_data = suite.sig().serialize_private(*priv);
|
||||
auto pub = priv->public_key();
|
||||
auto pub_data = suite.sig().serialize(*pub);
|
||||
return { priv_data, pub_data };
|
||||
}
|
||||
|
||||
bytes
|
||||
SignaturePrivateKey::sign(const CipherSuite& suite,
|
||||
const std::string& label,
|
||||
const bytes& message) const
|
||||
{
|
||||
auto label_plus = mls_1_0_plus(label);
|
||||
const auto content = tls::marshal(SignContent{ label_plus, message });
|
||||
const auto priv = suite.sig().deserialize_private(data);
|
||||
return suite.sig().sign(content, *priv);
|
||||
}
|
||||
|
||||
SignaturePrivateKey::SignaturePrivateKey(bytes priv_data, bytes pub_data)
|
||||
: data(std::move(priv_data))
|
||||
, public_key{ std::move(pub_data) }
|
||||
{
|
||||
}
|
||||
|
||||
void
|
||||
SignaturePrivateKey::set_public_key(CipherSuite suite)
|
||||
{
|
||||
const auto priv = suite.sig().deserialize_private(data);
|
||||
auto pub = priv->public_key();
|
||||
public_key.data = suite.sig().serialize(*pub);
|
||||
}
|
||||
|
||||
SignaturePrivateKey
|
||||
SignaturePrivateKey::from_jwk(CipherSuite suite, const std::string& json_str)
|
||||
{
|
||||
auto priv = suite.sig().import_jwk_private(json_str);
|
||||
auto priv_data = suite.sig().serialize_private(*priv);
|
||||
auto pub = priv->public_key();
|
||||
auto pub_data = suite.sig().serialize(*pub);
|
||||
return { priv_data, pub_data };
|
||||
}
|
||||
|
||||
std::string
|
||||
SignaturePrivateKey::to_jwk(CipherSuite suite) const
|
||||
{
|
||||
const auto priv = suite.sig().deserialize_private(data);
|
||||
return suite.sig().export_jwk_private(*priv);
|
||||
}
|
||||
|
||||
} // namespace mlspp
|
||||
126
DPP/mlspp/src/grease.cpp
Executable file
126
DPP/mlspp/src/grease.cpp
Executable file
@@ -0,0 +1,126 @@
|
||||
#include "grease.h"
|
||||
|
||||
#include <random>
|
||||
#include <set>
|
||||
|
||||
namespace mlspp {
|
||||
|
||||
#ifdef DISABLE_GREASE
|
||||
|
||||
void
|
||||
grease([[maybe_unused]] Capabilities& capabilities,
|
||||
[[maybe_unused]] const ExtensionList& extensions)
|
||||
{
|
||||
}
|
||||
|
||||
void
|
||||
grease([[maybe_unused]] ExtensionList& extensions)
|
||||
{
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
// Randomness parmeters:
|
||||
// * Given a list of N items, insert max(1, rand(p_grease * N)) GREASE values
|
||||
// * Each GREASE value added is distinct, unless more than 15 values are needed
|
||||
// * For extensions, each GREASE extension has rand(n_grease_ext) random bytes
|
||||
// of data
|
||||
const size_t log_p_grease = 1; // -log2(p_grease) => p_grease = 1/2
|
||||
const size_t max_grease_ext_size = 16;
|
||||
|
||||
const std::array<uint16_t, 15> grease_values = { 0x0A0A, 0x1A1A, 0x2A2A, 0x3A3A,
|
||||
0x4A4A, 0x5A5A, 0x6A6A, 0x7A7A,
|
||||
0x8A8A, 0x9A9A, 0xAAAA, 0xBABA,
|
||||
0xCACA, 0xDADA, 0xEAEA };
|
||||
|
||||
static size_t
|
||||
rand_int(size_t n)
|
||||
{
|
||||
static auto seed = std::random_device()();
|
||||
static auto rng = std::mt19937(seed);
|
||||
return std::uniform_int_distribution<size_t>(0, n)(rng);
|
||||
}
|
||||
|
||||
static uint16_t
|
||||
grease_value()
|
||||
{
|
||||
const auto where = rand_int(grease_values.size() - 1);
|
||||
return grease_values.at(where);
|
||||
}
|
||||
|
||||
static bool
|
||||
grease_value(uint16_t val)
|
||||
{
|
||||
static constexpr auto grease_mask = uint16_t(0x0F0F);
|
||||
return ((val & grease_mask) == 0x0A0A) && val != 0xFAFA;
|
||||
}
|
||||
|
||||
static std::set<uint16_t>
|
||||
grease_sample(size_t count)
|
||||
{
|
||||
auto vals = std::set<uint16_t>{};
|
||||
|
||||
while (vals.size() < count) {
|
||||
uint16_t val = grease_value();
|
||||
while (vals.count(val) > 0 && vals.size() < grease_values.size()) {
|
||||
val = grease_value();
|
||||
}
|
||||
|
||||
vals.insert(val);
|
||||
}
|
||||
|
||||
return vals;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void
|
||||
grease(std::vector<T>& vec)
|
||||
{
|
||||
const auto count = std::max(size_t(1), rand_int(vec.size() >> log_p_grease));
|
||||
for (const auto val : grease_sample(count)) {
|
||||
const auto where = static_cast<ptrdiff_t>(rand_int(vec.size()));
|
||||
vec.insert(std::begin(vec) + where, static_cast<T>(val));
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
grease(Capabilities& capabilities, const ExtensionList& extensions)
|
||||
{
|
||||
// Add GREASE to the appropriate portions of the capabilities
|
||||
grease(capabilities.cipher_suites);
|
||||
grease(capabilities.extensions);
|
||||
grease(capabilities.proposals);
|
||||
grease(capabilities.credentials);
|
||||
|
||||
// Ensure that the GREASE extensions are reflected in Capabilities.extensions
|
||||
for (const auto& ext : extensions.extensions) {
|
||||
if (!grease_value(ext.type)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (stdx::contains(capabilities.extensions, ext.type)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto where =
|
||||
static_cast<ptrdiff_t>(rand_int(capabilities.extensions.size()));
|
||||
const auto where_ptr = std::begin(capabilities.extensions) + where;
|
||||
capabilities.extensions.insert(where_ptr, ext.type);
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
grease(ExtensionList& extensions)
|
||||
{
|
||||
auto& ext = extensions.extensions;
|
||||
const auto count = std::max(size_t(1), rand_int(ext.size() >> log_p_grease));
|
||||
for (const auto ext_type : grease_sample(count)) {
|
||||
const auto where = static_cast<ptrdiff_t>(rand_int(ext.size()));
|
||||
auto ext_data = random_bytes(rand_int(max_grease_ext_size));
|
||||
ext.insert(std::begin(ext) + where, { ext_type, std::move(ext_data) });
|
||||
}
|
||||
}
|
||||
|
||||
#endif // DISABLE_GREASE
|
||||
|
||||
} // namespace mlspp
|
||||
13
DPP/mlspp/src/grease.h
Executable file
13
DPP/mlspp/src/grease.h
Executable file
@@ -0,0 +1,13 @@
|
||||
#pragma once
|
||||
|
||||
#include "mls/core_types.h"
|
||||
|
||||
namespace mlspp {
|
||||
|
||||
void
|
||||
grease(Capabilities& capabilities, const ExtensionList& extensions);
|
||||
|
||||
void
|
||||
grease(ExtensionList& extensions);
|
||||
|
||||
} // namespace mlspp
|
||||
579
DPP/mlspp/src/key_schedule.cpp
Executable file
579
DPP/mlspp/src/key_schedule.cpp
Executable file
@@ -0,0 +1,579 @@
|
||||
#include <mls/key_schedule.h>
|
||||
|
||||
namespace mlspp {
|
||||
|
||||
///
|
||||
/// Key Derivation Functions
|
||||
///
|
||||
|
||||
struct TreeContext
|
||||
{
|
||||
NodeIndex node;
|
||||
uint32_t generation = 0;
|
||||
|
||||
TLS_SERIALIZABLE(node, generation)
|
||||
};
|
||||
|
||||
///
|
||||
/// HashRatchet
|
||||
///
|
||||
|
||||
HashRatchet::HashRatchet(CipherSuite suite_in, bytes base_secret_in)
|
||||
: suite(suite_in)
|
||||
, next_secret(std::move(base_secret_in))
|
||||
, next_generation(0)
|
||||
, key_size(suite.hpke().aead.key_size)
|
||||
, nonce_size(suite.hpke().aead.nonce_size)
|
||||
, secret_size(suite.secret_size())
|
||||
{
|
||||
}
|
||||
|
||||
std::tuple<uint32_t, KeyAndNonce>
|
||||
HashRatchet::next()
|
||||
{
|
||||
auto generation = next_generation;
|
||||
auto key = suite.derive_tree_secret(next_secret, "key", generation, key_size);
|
||||
auto nonce =
|
||||
suite.derive_tree_secret(next_secret, "nonce", generation, nonce_size);
|
||||
auto secret =
|
||||
suite.derive_tree_secret(next_secret, "secret", generation, secret_size);
|
||||
|
||||
next_generation += 1;
|
||||
next_secret = secret;
|
||||
|
||||
cache[generation] = { key, nonce };
|
||||
return { generation, cache.at(generation) };
|
||||
}
|
||||
|
||||
// Note: This construction deliberately does not preserve the forward-secrecy
|
||||
// invariant, in that keys/nonces are not deleted after they are used.
|
||||
// Otherwise, it would not be possible for a node to send to itself. Keys can
|
||||
// be deleted once they are not needed by calling HashRatchet::erase().
|
||||
KeyAndNonce
|
||||
HashRatchet::get(uint32_t generation)
|
||||
{
|
||||
if (cache.count(generation) > 0) {
|
||||
auto out = cache.at(generation);
|
||||
return out;
|
||||
}
|
||||
|
||||
if (next_generation > generation) {
|
||||
throw ProtocolError("Request for expired key");
|
||||
}
|
||||
|
||||
while (next_generation <= generation) {
|
||||
next();
|
||||
}
|
||||
|
||||
return cache.at(generation);
|
||||
}
|
||||
|
||||
void
|
||||
HashRatchet::erase(uint32_t generation)
|
||||
{
|
||||
if (cache.count(generation) == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
cache.erase(generation);
|
||||
}
|
||||
|
||||
///
|
||||
/// SecretTree
|
||||
///
|
||||
|
||||
SecretTree::SecretTree(CipherSuite suite_in,
|
||||
LeafCount group_size_in,
|
||||
bytes encryption_secret_in)
|
||||
: suite(suite_in)
|
||||
, group_size(LeafCount::full(group_size_in))
|
||||
, root(NodeIndex::root(group_size))
|
||||
, secret_size(suite_in.secret_size())
|
||||
{
|
||||
secrets.emplace(root, std::move(encryption_secret_in));
|
||||
}
|
||||
|
||||
bytes
|
||||
SecretTree::get(LeafIndex sender)
|
||||
{
|
||||
static const auto context_left = from_ascii("left");
|
||||
static const auto context_right = from_ascii("right");
|
||||
auto node = NodeIndex(sender);
|
||||
|
||||
// Find an ancestor that is populated
|
||||
auto dirpath = node.dirpath(group_size);
|
||||
dirpath.insert(dirpath.begin(), node);
|
||||
dirpath.push_back(root);
|
||||
uint32_t curr = 0;
|
||||
for (; curr < dirpath.size(); ++curr) {
|
||||
auto i = dirpath.at(curr);
|
||||
if (secrets.count(i) > 0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (curr > dirpath.size()) {
|
||||
throw InvalidParameterError("No secret found to derive base key");
|
||||
}
|
||||
|
||||
// Derive down
|
||||
for (; curr > 0; --curr) {
|
||||
auto curr_node = dirpath.at(curr);
|
||||
auto left = curr_node.left();
|
||||
auto right = curr_node.right();
|
||||
|
||||
auto& secret = secrets.at(curr_node);
|
||||
|
||||
const auto left_secret =
|
||||
suite.expand_with_label(secret, "tree", context_left, secret_size);
|
||||
const auto right_secret =
|
||||
suite.expand_with_label(secret, "tree", context_right, secret_size);
|
||||
|
||||
secrets.insert_or_assign(left, left_secret);
|
||||
secrets.insert_or_assign(right, right_secret);
|
||||
}
|
||||
|
||||
// Copy the leaf
|
||||
auto out = secrets.at(node);
|
||||
|
||||
// Zeroize along the direct path
|
||||
for (auto i : dirpath) {
|
||||
secrets.erase(i);
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
///
|
||||
/// ReuseGuard
|
||||
///
|
||||
|
||||
static ReuseGuard
|
||||
new_reuse_guard()
|
||||
{
|
||||
auto random = random_bytes(4);
|
||||
auto guard = ReuseGuard();
|
||||
std::copy(random.begin(), random.end(), guard.begin());
|
||||
return guard;
|
||||
}
|
||||
|
||||
static void
|
||||
apply_reuse_guard(const ReuseGuard& guard, bytes& nonce)
|
||||
{
|
||||
for (size_t i = 0; i < guard.size(); i++) {
|
||||
nonce.at(i) ^= guard.at(i);
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
/// GroupKeySource
|
||||
///
|
||||
|
||||
GroupKeySource::GroupKeySource(CipherSuite suite_in,
|
||||
LeafCount group_size,
|
||||
bytes encryption_secret)
|
||||
: suite(suite_in)
|
||||
, secret_tree(suite, group_size, std::move(encryption_secret))
|
||||
{
|
||||
}
|
||||
|
||||
HashRatchet&
|
||||
GroupKeySource::chain(ContentType type, LeafIndex sender)
|
||||
{
|
||||
switch (type) {
|
||||
case ContentType::proposal:
|
||||
case ContentType::commit:
|
||||
return chain(RatchetType::handshake, sender);
|
||||
|
||||
case ContentType::application:
|
||||
return chain(RatchetType::application, sender);
|
||||
|
||||
default:
|
||||
throw InvalidParameterError("Invalid content type");
|
||||
}
|
||||
}
|
||||
|
||||
HashRatchet&
|
||||
GroupKeySource::chain(RatchetType type, LeafIndex sender)
|
||||
{
|
||||
auto key = Key{ type, sender };
|
||||
if (chains.count(key) > 0) {
|
||||
return chains[key];
|
||||
}
|
||||
|
||||
auto secret_size = suite.secret_size();
|
||||
auto leaf_secret = secret_tree.get(sender);
|
||||
|
||||
auto handshake_secret =
|
||||
suite.expand_with_label(leaf_secret, "handshake", {}, secret_size);
|
||||
auto application_secret =
|
||||
suite.expand_with_label(leaf_secret, "application", {}, secret_size);
|
||||
|
||||
chains.emplace(Key{ RatchetType::handshake, sender },
|
||||
HashRatchet{ suite, handshake_secret });
|
||||
chains.emplace(Key{ RatchetType::application, sender },
|
||||
HashRatchet{ suite, application_secret });
|
||||
|
||||
return chains[key];
|
||||
}
|
||||
|
||||
std::tuple<uint32_t, ReuseGuard, KeyAndNonce>
|
||||
GroupKeySource::next(ContentType type, LeafIndex sender)
|
||||
{
|
||||
auto [generation, keys] = chain(type, sender).next();
|
||||
|
||||
auto reuse_guard = new_reuse_guard();
|
||||
apply_reuse_guard(reuse_guard, keys.nonce);
|
||||
|
||||
return { generation, reuse_guard, keys };
|
||||
}
|
||||
|
||||
KeyAndNonce
|
||||
GroupKeySource::get(ContentType type,
|
||||
LeafIndex sender,
|
||||
uint32_t generation,
|
||||
ReuseGuard reuse_guard)
|
||||
{
|
||||
auto keys = chain(type, sender).get(generation);
|
||||
apply_reuse_guard(reuse_guard, keys.nonce);
|
||||
return keys;
|
||||
}
|
||||
|
||||
void
|
||||
GroupKeySource::erase(ContentType type, LeafIndex sender, uint32_t generation)
|
||||
{
|
||||
return chain(type, sender).erase(generation);
|
||||
}
|
||||
|
||||
// struct {
|
||||
// opaque group_id<0..255>;
|
||||
// uint64 epoch;
|
||||
// ContentType content_type;
|
||||
// opaque authenticated_data<0..2^32-1>;
|
||||
// } ContentAAD;
|
||||
struct ContentAAD
|
||||
{
|
||||
const bytes& group_id;
|
||||
const epoch_t epoch;
|
||||
const ContentType content_type;
|
||||
const bytes& authenticated_data;
|
||||
|
||||
TLS_SERIALIZABLE(group_id, epoch, content_type, authenticated_data)
|
||||
};
|
||||
|
||||
///
|
||||
/// KeyScheduleEpoch
|
||||
///
|
||||
|
||||
struct PSKLabel
|
||||
{
|
||||
const PreSharedKeyID& id;
|
||||
uint16_t index;
|
||||
uint16_t count;
|
||||
|
||||
TLS_SERIALIZABLE(id, index, count);
|
||||
};
|
||||
|
||||
static bytes
|
||||
make_joiner_secret(CipherSuite suite,
|
||||
const bytes& context,
|
||||
const bytes& init_secret,
|
||||
const bytes& commit_secret)
|
||||
{
|
||||
auto pre_joiner_secret = suite.hpke().kdf.extract(init_secret, commit_secret);
|
||||
return suite.expand_with_label(
|
||||
pre_joiner_secret, "joiner", context, suite.secret_size());
|
||||
}
|
||||
|
||||
static bytes
|
||||
make_epoch_secret(CipherSuite suite,
|
||||
const bytes& joiner_secret,
|
||||
const bytes& psk_secret,
|
||||
const bytes& context)
|
||||
{
|
||||
auto member_secret = suite.hpke().kdf.extract(joiner_secret, psk_secret);
|
||||
return suite.expand_with_label(
|
||||
member_secret, "epoch", context, suite.secret_size());
|
||||
}
|
||||
|
||||
KeyScheduleEpoch
|
||||
KeyScheduleEpoch::joiner(CipherSuite suite_in,
|
||||
const bytes& joiner_secret,
|
||||
const std::vector<PSKWithSecret>& psks,
|
||||
const bytes& context)
|
||||
{
|
||||
return { suite_in, joiner_secret, make_psk_secret(suite_in, psks), context };
|
||||
}
|
||||
|
||||
KeyScheduleEpoch::KeyScheduleEpoch(CipherSuite suite_in,
|
||||
const bytes& joiner_secret,
|
||||
const bytes& psk_secret,
|
||||
const bytes& context)
|
||||
: suite(suite_in)
|
||||
, joiner_secret(joiner_secret)
|
||||
, epoch_secret(
|
||||
make_epoch_secret(suite_in, joiner_secret, psk_secret, context))
|
||||
, sender_data_secret(suite.derive_secret(epoch_secret, "sender data"))
|
||||
, encryption_secret(suite.derive_secret(epoch_secret, "encryption"))
|
||||
, exporter_secret(suite.derive_secret(epoch_secret, "exporter"))
|
||||
, epoch_authenticator(suite.derive_secret(epoch_secret, "authentication"))
|
||||
, external_secret(suite.derive_secret(epoch_secret, "external"))
|
||||
, confirmation_key(suite.derive_secret(epoch_secret, "confirm"))
|
||||
, membership_key(suite.derive_secret(epoch_secret, "membership"))
|
||||
, resumption_psk(suite.derive_secret(epoch_secret, "resumption"))
|
||||
, init_secret(suite.derive_secret(epoch_secret, "init"))
|
||||
, external_priv(HPKEPrivateKey::derive(suite, external_secret))
|
||||
{
|
||||
}
|
||||
|
||||
KeyScheduleEpoch::KeyScheduleEpoch(CipherSuite suite_in)
|
||||
: suite(suite_in)
|
||||
{
|
||||
}
|
||||
|
||||
KeyScheduleEpoch::KeyScheduleEpoch(CipherSuite suite_in,
|
||||
const bytes& init_secret,
|
||||
const bytes& context)
|
||||
: KeyScheduleEpoch(
|
||||
suite_in,
|
||||
make_joiner_secret(suite_in, context, init_secret, suite_in.zero()),
|
||||
{ /* no PSKs */ },
|
||||
context)
|
||||
{
|
||||
}
|
||||
|
||||
KeyScheduleEpoch::KeyScheduleEpoch(CipherSuite suite_in,
|
||||
const bytes& init_secret,
|
||||
const bytes& commit_secret,
|
||||
const bytes& psk_secret,
|
||||
const bytes& context)
|
||||
: KeyScheduleEpoch(
|
||||
suite_in,
|
||||
make_joiner_secret(suite_in, context, init_secret, commit_secret),
|
||||
psk_secret,
|
||||
context)
|
||||
{
|
||||
}
|
||||
|
||||
std::tuple<bytes, bytes>
|
||||
KeyScheduleEpoch::external_init(CipherSuite suite,
|
||||
const HPKEPublicKey& external_pub)
|
||||
{
|
||||
auto size = suite.secret_size();
|
||||
return external_pub.do_export(
|
||||
suite, {}, "MLS 1.0 external init secret", size);
|
||||
}
|
||||
|
||||
bytes
|
||||
KeyScheduleEpoch::receive_external_init(const bytes& kem_output) const
|
||||
{
|
||||
auto size = suite.secret_size();
|
||||
return external_priv.do_export(
|
||||
suite, {}, kem_output, "MLS 1.0 external init secret", size);
|
||||
}
|
||||
|
||||
KeyScheduleEpoch
|
||||
KeyScheduleEpoch::next(const bytes& commit_secret,
|
||||
const std::vector<PSKWithSecret>& psks,
|
||||
const std::optional<bytes>& force_init_secret,
|
||||
const bytes& context) const
|
||||
{
|
||||
return next_raw(
|
||||
commit_secret, make_psk_secret(suite, psks), force_init_secret, context);
|
||||
}
|
||||
|
||||
KeyScheduleEpoch
|
||||
KeyScheduleEpoch::next_raw(const bytes& commit_secret,
|
||||
const bytes& psk_secret,
|
||||
const std::optional<bytes>& force_init_secret,
|
||||
const bytes& context) const
|
||||
{
|
||||
auto actual_init_secret = init_secret;
|
||||
if (force_init_secret) {
|
||||
actual_init_secret = opt::get(force_init_secret);
|
||||
}
|
||||
|
||||
return { suite, actual_init_secret, commit_secret, psk_secret, context };
|
||||
}
|
||||
|
||||
GroupKeySource
|
||||
KeyScheduleEpoch::encryption_keys(LeafCount size) const
|
||||
{
|
||||
return { suite, size, encryption_secret };
|
||||
}
|
||||
|
||||
bytes
|
||||
KeyScheduleEpoch::confirmation_tag(const bytes& confirmed_transcript_hash) const
|
||||
{
|
||||
return suite.digest().hmac(confirmation_key, confirmed_transcript_hash);
|
||||
}
|
||||
|
||||
bytes
|
||||
KeyScheduleEpoch::do_export(const std::string& label,
|
||||
const bytes& context,
|
||||
size_t size) const
|
||||
{
|
||||
auto secret = suite.derive_secret(exporter_secret, label);
|
||||
auto context_hash = suite.digest().hash(context);
|
||||
return suite.expand_with_label(secret, "exported", context_hash, size);
|
||||
}
|
||||
|
||||
PSKWithSecret
|
||||
KeyScheduleEpoch::resumption_psk_w_secret(ResumptionPSKUsage usage,
|
||||
const bytes& group_id,
|
||||
epoch_t epoch)
|
||||
{
|
||||
auto nonce = random_bytes(suite.secret_size());
|
||||
auto psk = ResumptionPSK{ usage, group_id, epoch };
|
||||
return { { psk, nonce }, resumption_psk };
|
||||
}
|
||||
|
||||
bytes
|
||||
KeyScheduleEpoch::make_psk_secret(CipherSuite suite,
|
||||
const std::vector<PSKWithSecret>& psks)
|
||||
{
|
||||
auto psk_secret = suite.zero();
|
||||
auto count = uint16_t(psks.size());
|
||||
auto index = uint16_t(0);
|
||||
for (const auto& psk : psks) {
|
||||
auto psk_extracted = suite.hpke().kdf.extract(suite.zero(), psk.secret);
|
||||
auto psk_label = tls::marshal(PSKLabel{ psk.id, index, count });
|
||||
auto psk_input = suite.expand_with_label(
|
||||
psk_extracted, "derived psk", psk_label, suite.secret_size());
|
||||
psk_secret = suite.hpke().kdf.extract(psk_input, psk_secret);
|
||||
index += 1;
|
||||
}
|
||||
return psk_secret;
|
||||
}
|
||||
|
||||
bytes
|
||||
KeyScheduleEpoch::welcome_secret(CipherSuite suite,
|
||||
const bytes& joiner_secret,
|
||||
const std::vector<PSKWithSecret>& psks)
|
||||
{
|
||||
auto psk_secret = make_psk_secret(suite, psks);
|
||||
return welcome_secret_raw(suite, joiner_secret, psk_secret);
|
||||
}
|
||||
|
||||
bytes
|
||||
KeyScheduleEpoch::welcome_secret_raw(CipherSuite suite,
|
||||
const bytes& joiner_secret,
|
||||
const bytes& psk_secret)
|
||||
{
|
||||
auto extract = suite.hpke().kdf.extract(joiner_secret, psk_secret);
|
||||
return suite.derive_secret(extract, "welcome");
|
||||
}
|
||||
|
||||
KeyAndNonce
|
||||
KeyScheduleEpoch::sender_data_keys(CipherSuite suite,
|
||||
const bytes& sender_data_secret,
|
||||
const bytes& ciphertext)
|
||||
{
|
||||
auto sample_size = suite.secret_size();
|
||||
auto sample = bytes(sample_size);
|
||||
if (ciphertext.size() <= sample_size) {
|
||||
sample = ciphertext;
|
||||
} else {
|
||||
sample = ciphertext.slice(0, sample_size);
|
||||
}
|
||||
|
||||
auto key_size = suite.hpke().aead.key_size;
|
||||
auto nonce_size = suite.hpke().aead.nonce_size;
|
||||
return {
|
||||
suite.expand_with_label(sender_data_secret, "key", sample, key_size),
|
||||
suite.expand_with_label(sender_data_secret, "nonce", sample, nonce_size),
|
||||
};
|
||||
}
|
||||
|
||||
bool
|
||||
operator==(const KeyScheduleEpoch& lhs, const KeyScheduleEpoch& rhs)
|
||||
{
|
||||
auto epoch_secret = (lhs.epoch_secret == rhs.epoch_secret);
|
||||
auto sender_data_secret = (lhs.sender_data_secret == rhs.sender_data_secret);
|
||||
auto encryption_secret = (lhs.encryption_secret == rhs.encryption_secret);
|
||||
auto exporter_secret = (lhs.exporter_secret == rhs.exporter_secret);
|
||||
auto confirmation_key = (lhs.confirmation_key == rhs.confirmation_key);
|
||||
auto init_secret = (lhs.init_secret == rhs.init_secret);
|
||||
auto external_priv = (lhs.external_priv == rhs.external_priv);
|
||||
|
||||
return epoch_secret && sender_data_secret && encryption_secret &&
|
||||
exporter_secret && confirmation_key && init_secret && external_priv;
|
||||
}
|
||||
|
||||
// struct {
|
||||
// WireFormat wire_format;
|
||||
// GroupContent content; // with content.content_type == commit
|
||||
// opaque signature<V>;
|
||||
// } ConfirmedTranscriptHashInput;
|
||||
struct ConfirmedTranscriptHashInput
|
||||
{
|
||||
WireFormat wire_format;
|
||||
const GroupContent& content;
|
||||
const bytes& signature;
|
||||
|
||||
TLS_SERIALIZABLE(wire_format, content, signature)
|
||||
};
|
||||
|
||||
// struct {
|
||||
// MAC confirmation_tag;
|
||||
// } InterimTranscriptHashInput;
|
||||
struct InterimTranscriptHashInput
|
||||
{
|
||||
bytes confirmation_tag;
|
||||
|
||||
TLS_SERIALIZABLE(confirmation_tag)
|
||||
};
|
||||
|
||||
TranscriptHash::TranscriptHash(CipherSuite suite_in)
|
||||
: suite(suite_in)
|
||||
{
|
||||
}
|
||||
|
||||
TranscriptHash::TranscriptHash(CipherSuite suite_in,
|
||||
bytes confirmed_in,
|
||||
const bytes& confirmation_tag)
|
||||
: suite(suite_in)
|
||||
, confirmed(std::move(confirmed_in))
|
||||
{
|
||||
update_interim(confirmation_tag);
|
||||
}
|
||||
|
||||
void
|
||||
TranscriptHash::update(const AuthenticatedContent& content_auth)
|
||||
{
|
||||
update_confirmed(content_auth);
|
||||
update_interim(content_auth);
|
||||
}
|
||||
|
||||
void
|
||||
TranscriptHash::update_confirmed(const AuthenticatedContent& content_auth)
|
||||
{
|
||||
const auto transcript =
|
||||
interim + content_auth.confirmed_transcript_hash_input();
|
||||
confirmed = suite.digest().hash(transcript);
|
||||
}
|
||||
|
||||
void
|
||||
TranscriptHash::update_interim(const bytes& confirmation_tag)
|
||||
{
|
||||
const auto transcript = confirmed + tls::marshal(confirmation_tag);
|
||||
interim = suite.digest().hash(transcript);
|
||||
}
|
||||
|
||||
void
|
||||
TranscriptHash::update_interim(const AuthenticatedContent& content_auth)
|
||||
{
|
||||
const auto transcript =
|
||||
confirmed + content_auth.interim_transcript_hash_input();
|
||||
interim = suite.digest().hash(transcript);
|
||||
}
|
||||
|
||||
bool
|
||||
operator==(const TranscriptHash& lhs, const TranscriptHash& rhs)
|
||||
{
|
||||
auto confirmed = (lhs.confirmed == rhs.confirmed);
|
||||
auto interim = (lhs.interim == rhs.interim);
|
||||
return confirmed && interim;
|
||||
}
|
||||
|
||||
} // namespace mlspp
|
||||
947
DPP/mlspp/src/messages.cpp
Executable file
947
DPP/mlspp/src/messages.cpp
Executable file
@@ -0,0 +1,947 @@
|
||||
#include <mls/key_schedule.h>
|
||||
#include <mls/messages.h>
|
||||
#include <mls/state.h>
|
||||
#include <mls/treekem.h>
|
||||
|
||||
#include "grease.h"
|
||||
|
||||
namespace mlspp {
|
||||
|
||||
// Extensions
|
||||
|
||||
const Extension::Type ExternalPubExtension::type = ExtensionType::external_pub;
|
||||
const Extension::Type RatchetTreeExtension::type = ExtensionType::ratchet_tree;
|
||||
const Extension::Type ExternalSendersExtension::type =
|
||||
ExtensionType::external_senders;
|
||||
const Extension::Type SFrameParameters::type = ExtensionType::sframe_parameters;
|
||||
const Extension::Type SFrameCapabilities::type =
|
||||
ExtensionType::sframe_parameters;
|
||||
|
||||
bool
|
||||
SFrameCapabilities::compatible(const SFrameParameters& params) const
|
||||
{
|
||||
return stdx::contains(cipher_suites, params.cipher_suite);
|
||||
}
|
||||
|
||||
// GroupContext
|
||||
|
||||
GroupContext::GroupContext(CipherSuite cipher_suite_in,
|
||||
bytes group_id_in,
|
||||
epoch_t epoch_in,
|
||||
bytes tree_hash_in,
|
||||
bytes confirmed_transcript_hash_in,
|
||||
ExtensionList extensions_in)
|
||||
: cipher_suite(cipher_suite_in)
|
||||
, group_id(std::move(group_id_in))
|
||||
, epoch(epoch_in)
|
||||
, tree_hash(std::move(tree_hash_in))
|
||||
, confirmed_transcript_hash(std::move(confirmed_transcript_hash_in))
|
||||
, extensions(std::move(extensions_in))
|
||||
{
|
||||
}
|
||||
|
||||
// GroupInfo
|
||||
|
||||
GroupInfo::GroupInfo(GroupContext group_context_in,
|
||||
ExtensionList extensions_in,
|
||||
bytes confirmation_tag_in)
|
||||
: group_context(std::move(group_context_in))
|
||||
, extensions(std::move(extensions_in))
|
||||
, confirmation_tag(std::move(confirmation_tag_in))
|
||||
, signer(0)
|
||||
{
|
||||
grease(extensions);
|
||||
}
|
||||
|
||||
struct GroupInfoTBS
|
||||
{
|
||||
GroupContext group_context;
|
||||
ExtensionList extensions;
|
||||
bytes confirmation_tag;
|
||||
LeafIndex signer;
|
||||
|
||||
TLS_SERIALIZABLE(group_context, extensions, confirmation_tag, signer)
|
||||
};
|
||||
|
||||
bytes
|
||||
GroupInfo::to_be_signed() const
|
||||
{
|
||||
return tls::marshal(
|
||||
GroupInfoTBS{ group_context, extensions, confirmation_tag, signer });
|
||||
}
|
||||
|
||||
void
|
||||
GroupInfo::sign(const TreeKEMPublicKey& tree,
|
||||
LeafIndex signer_index,
|
||||
const SignaturePrivateKey& priv)
|
||||
{
|
||||
auto maybe_leaf = tree.leaf_node(signer_index);
|
||||
if (!maybe_leaf) {
|
||||
throw InvalidParameterError("Cannot sign from a blank leaf");
|
||||
}
|
||||
|
||||
if (priv.public_key != opt::get(maybe_leaf).signature_key) {
|
||||
throw InvalidParameterError("Bad key for index");
|
||||
}
|
||||
|
||||
signer = signer_index;
|
||||
signature = priv.sign(tree.suite, sign_label::group_info, to_be_signed());
|
||||
}
|
||||
|
||||
bool
|
||||
GroupInfo::verify(const TreeKEMPublicKey& tree) const
|
||||
{
|
||||
auto maybe_leaf = tree.leaf_node(signer);
|
||||
if (!maybe_leaf) {
|
||||
throw InvalidParameterError("Signer not found");
|
||||
}
|
||||
|
||||
const auto& leaf = opt::get(maybe_leaf);
|
||||
return verify(leaf.signature_key);
|
||||
}
|
||||
|
||||
void
|
||||
GroupInfo::sign(LeafIndex signer_index, const SignaturePrivateKey& priv)
|
||||
{
|
||||
signer = signer_index;
|
||||
signature = priv.sign(
|
||||
group_context.cipher_suite, sign_label::group_info, to_be_signed());
|
||||
}
|
||||
|
||||
bool
|
||||
GroupInfo::verify(const SignaturePublicKey& pub) const
|
||||
{
|
||||
return pub.verify(group_context.cipher_suite,
|
||||
sign_label::group_info,
|
||||
to_be_signed(),
|
||||
signature);
|
||||
}
|
||||
|
||||
// Welcome
|
||||
|
||||
Welcome::Welcome()
|
||||
: cipher_suite(CipherSuite::ID::unknown)
|
||||
{
|
||||
}
|
||||
|
||||
Welcome::Welcome(CipherSuite suite,
|
||||
const bytes& joiner_secret,
|
||||
const std::vector<PSKWithSecret>& psks,
|
||||
const GroupInfo& group_info)
|
||||
: cipher_suite(suite)
|
||||
, _joiner_secret(joiner_secret)
|
||||
{
|
||||
// Cache the list of PSK IDs
|
||||
for (const auto& psk : psks) {
|
||||
_psks.psks.push_back(psk.id);
|
||||
}
|
||||
|
||||
// Pre-encrypt the GroupInfo
|
||||
auto [key, nonce] = group_info_key_nonce(suite, joiner_secret, psks);
|
||||
auto group_info_data = tls::marshal(group_info);
|
||||
encrypted_group_info =
|
||||
cipher_suite.hpke().aead.seal(key, nonce, {}, group_info_data);
|
||||
}
|
||||
|
||||
std::optional<int>
|
||||
Welcome::find(const KeyPackage& kp) const
|
||||
{
|
||||
auto ref = kp.ref();
|
||||
for (size_t i = 0; i < secrets.size(); i++) {
|
||||
if (ref == secrets[i].new_member) {
|
||||
return static_cast<int>(i);
|
||||
}
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
void
|
||||
Welcome::encrypt(const KeyPackage& kp, const std::optional<bytes>& path_secret)
|
||||
{
|
||||
auto gs = GroupSecrets{ _joiner_secret, std::nullopt, _psks };
|
||||
if (path_secret) {
|
||||
gs.path_secret = GroupSecrets::PathSecret{ opt::get(path_secret) };
|
||||
}
|
||||
|
||||
auto gs_data = tls::marshal(gs);
|
||||
auto enc_gs = kp.init_key.encrypt(
|
||||
kp.cipher_suite, encrypt_label::welcome, encrypted_group_info, gs_data);
|
||||
secrets.push_back({ kp.ref(), enc_gs });
|
||||
}
|
||||
|
||||
GroupSecrets
|
||||
Welcome::decrypt_secrets(int kp_index, const HPKEPrivateKey& init_priv) const
|
||||
{
|
||||
auto secrets_data =
|
||||
init_priv.decrypt(cipher_suite,
|
||||
encrypt_label::welcome,
|
||||
encrypted_group_info,
|
||||
secrets.at(kp_index).encrypted_group_secrets);
|
||||
return tls::get<GroupSecrets>(secrets_data);
|
||||
}
|
||||
|
||||
GroupInfo
|
||||
Welcome::decrypt(const bytes& joiner_secret,
|
||||
const std::vector<PSKWithSecret>& psks) const
|
||||
{
|
||||
auto [key, nonce] = group_info_key_nonce(cipher_suite, joiner_secret, psks);
|
||||
auto group_info_data =
|
||||
cipher_suite.hpke().aead.open(key, nonce, {}, encrypted_group_info);
|
||||
if (!group_info_data) {
|
||||
throw ProtocolError("Welcome decryption failed");
|
||||
}
|
||||
|
||||
return tls::get<GroupInfo>(opt::get(group_info_data));
|
||||
}
|
||||
|
||||
KeyAndNonce
|
||||
Welcome::group_info_key_nonce(CipherSuite suite,
|
||||
const bytes& joiner_secret,
|
||||
const std::vector<PSKWithSecret>& psks)
|
||||
{
|
||||
auto welcome_secret =
|
||||
KeyScheduleEpoch::welcome_secret(suite, joiner_secret, psks);
|
||||
|
||||
// XXX(RLB): These used to be done with ExpandWithLabel. Should we do that
|
||||
// instead, for better domain separation? (In particular, including "mls10")
|
||||
// That is what we do for the sender data key/nonce.
|
||||
auto key =
|
||||
suite.expand_with_label(welcome_secret, "key", {}, suite.key_size());
|
||||
auto nonce =
|
||||
suite.expand_with_label(welcome_secret, "nonce", {}, suite.nonce_size());
|
||||
return { std::move(key), std::move(nonce) };
|
||||
}
|
||||
|
||||
// Commit
|
||||
std::optional<bytes>
|
||||
Commit::valid_external() const
|
||||
{
|
||||
// External Commits MUST contain a path field (and is therefore a "full"
|
||||
// Commit). The joiner is added at the leftmost free leaf node (just as if
|
||||
// they were added with an Add proposal), and the path is calculated relative
|
||||
// to that leaf node.
|
||||
//
|
||||
// The Commit MUST NOT include any proposals by reference, since an external
|
||||
// joiner cannot determine the validity of proposals sent within the group
|
||||
const auto all_by_value = stdx::all_of(proposals, [](const auto& p) {
|
||||
return var::holds_alternative<Proposal>(p.content);
|
||||
});
|
||||
if (!path || !all_by_value) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
const auto ext_init_ptr = stdx::find_if(proposals, [](const auto& p) {
|
||||
const auto proposal = var::get<Proposal>(p.content);
|
||||
return proposal.proposal_type() == ProposalType::external_init;
|
||||
});
|
||||
if (ext_init_ptr == proposals.end()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
const auto& ext_init_proposal = var::get<Proposal>(ext_init_ptr->content);
|
||||
const auto& ext_init = var::get<ExternalInit>(ext_init_proposal.content);
|
||||
return ext_init.kem_output;
|
||||
}
|
||||
|
||||
// PublicMessage
|
||||
Proposal::Type
|
||||
Proposal::proposal_type() const
|
||||
{
|
||||
return tls::variant<ProposalType>::type(content).val;
|
||||
}
|
||||
|
||||
SenderType
|
||||
Sender::sender_type() const
|
||||
{
|
||||
return tls::variant<SenderType>::type(sender);
|
||||
}
|
||||
|
||||
tls::ostream&
|
||||
operator<<(tls::ostream& str, const GroupContentAuthData& obj)
|
||||
{
|
||||
switch (obj.content_type) {
|
||||
case ContentType::proposal:
|
||||
case ContentType::application:
|
||||
return str << obj.signature;
|
||||
|
||||
case ContentType::commit:
|
||||
return str << obj.signature << opt::get(obj.confirmation_tag);
|
||||
|
||||
default:
|
||||
throw InvalidParameterError("Invalid content type");
|
||||
}
|
||||
}
|
||||
|
||||
tls::istream&
|
||||
operator>>(tls::istream& str, GroupContentAuthData& obj)
|
||||
{
|
||||
switch (obj.content_type) {
|
||||
case ContentType::proposal:
|
||||
case ContentType::application:
|
||||
return str >> obj.signature;
|
||||
|
||||
case ContentType::commit:
|
||||
obj.confirmation_tag.emplace();
|
||||
return str >> obj.signature >> opt::get(obj.confirmation_tag);
|
||||
|
||||
default:
|
||||
throw InvalidParameterError("Invalid content type");
|
||||
}
|
||||
}
|
||||
|
||||
bool
|
||||
operator==(const GroupContentAuthData& lhs, const GroupContentAuthData& rhs)
|
||||
{
|
||||
return lhs.content_type == rhs.content_type &&
|
||||
lhs.signature == rhs.signature &&
|
||||
lhs.confirmation_tag == rhs.confirmation_tag;
|
||||
}
|
||||
|
||||
GroupContent::GroupContent(bytes group_id_in,
|
||||
epoch_t epoch_in,
|
||||
Sender sender_in,
|
||||
bytes authenticated_data_in,
|
||||
RawContent content_in)
|
||||
: group_id(std::move(group_id_in))
|
||||
, epoch(epoch_in)
|
||||
, sender(sender_in)
|
||||
, authenticated_data(std::move(authenticated_data_in))
|
||||
, content(std::move(content_in))
|
||||
{
|
||||
}
|
||||
|
||||
GroupContent::GroupContent(bytes group_id_in,
|
||||
epoch_t epoch_in,
|
||||
Sender sender_in,
|
||||
bytes authenticated_data_in,
|
||||
ContentType content_type)
|
||||
: group_id(std::move(group_id_in))
|
||||
, epoch(epoch_in)
|
||||
, sender(sender_in)
|
||||
, authenticated_data(std::move(authenticated_data_in))
|
||||
{
|
||||
switch (content_type) {
|
||||
case ContentType::commit:
|
||||
content.emplace<Commit>();
|
||||
break;
|
||||
|
||||
case ContentType::proposal:
|
||||
content.emplace<Proposal>();
|
||||
break;
|
||||
|
||||
case ContentType::application:
|
||||
content.emplace<ApplicationData>();
|
||||
break;
|
||||
|
||||
default:
|
||||
throw InvalidParameterError("Invalid content type");
|
||||
}
|
||||
}
|
||||
|
||||
ContentType
|
||||
GroupContent::content_type() const
|
||||
{
|
||||
return tls::variant<ContentType>::type(content);
|
||||
}
|
||||
|
||||
AuthenticatedContent
|
||||
AuthenticatedContent::sign(WireFormat wire_format,
|
||||
GroupContent content,
|
||||
CipherSuite suite,
|
||||
const SignaturePrivateKey& sig_priv,
|
||||
const std::optional<GroupContext>& context)
|
||||
{
|
||||
if (wire_format == WireFormat::mls_public_message &&
|
||||
content.content_type() == ContentType::application) {
|
||||
throw InvalidParameterError(
|
||||
"Application data cannot be sent as PublicMessage");
|
||||
}
|
||||
|
||||
auto content_auth = AuthenticatedContent{ wire_format, std::move(content) };
|
||||
auto tbs = content_auth.to_be_signed(context);
|
||||
content_auth.auth.signature =
|
||||
sig_priv.sign(suite, sign_label::mls_content, tbs);
|
||||
return content_auth;
|
||||
}
|
||||
|
||||
bool
|
||||
AuthenticatedContent::verify(CipherSuite suite,
|
||||
const SignaturePublicKey& sig_pub,
|
||||
const std::optional<GroupContext>& context) const
|
||||
{
|
||||
if (wire_format == WireFormat::mls_public_message &&
|
||||
content.content_type() == ContentType::application) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto tbs = to_be_signed(context);
|
||||
return sig_pub.verify(suite, sign_label::mls_content, tbs, auth.signature);
|
||||
}
|
||||
|
||||
struct ConfirmedTranscriptHashInput
|
||||
{
|
||||
WireFormat wire_format;
|
||||
const GroupContent& content;
|
||||
const bytes& signature;
|
||||
|
||||
TLS_SERIALIZABLE(wire_format, content, signature);
|
||||
};
|
||||
|
||||
struct InterimTranscriptHashInput
|
||||
{
|
||||
const bytes& confirmation_tag;
|
||||
|
||||
TLS_SERIALIZABLE(confirmation_tag);
|
||||
};
|
||||
|
||||
bytes
|
||||
AuthenticatedContent::confirmed_transcript_hash_input() const
|
||||
{
|
||||
return tls::marshal(ConfirmedTranscriptHashInput{
|
||||
wire_format,
|
||||
content,
|
||||
auth.signature,
|
||||
});
|
||||
}
|
||||
|
||||
bytes
|
||||
AuthenticatedContent::interim_transcript_hash_input() const
|
||||
{
|
||||
return tls::marshal(
|
||||
InterimTranscriptHashInput{ opt::get(auth.confirmation_tag) });
|
||||
}
|
||||
|
||||
void
|
||||
AuthenticatedContent::set_confirmation_tag(const bytes& confirmation_tag)
|
||||
{
|
||||
auth.confirmation_tag = confirmation_tag;
|
||||
}
|
||||
|
||||
bool
|
||||
AuthenticatedContent::check_confirmation_tag(
|
||||
const bytes& confirmation_tag) const
|
||||
{
|
||||
return confirmation_tag == opt::get(auth.confirmation_tag);
|
||||
}
|
||||
|
||||
tls::ostream&
|
||||
operator<<(tls::ostream& str, const AuthenticatedContent& obj)
|
||||
{
|
||||
return str << obj.wire_format << obj.content << obj.auth;
|
||||
}
|
||||
|
||||
tls::istream&
|
||||
operator>>(tls::istream& str, AuthenticatedContent& obj)
|
||||
{
|
||||
str >> obj.wire_format >> obj.content;
|
||||
|
||||
obj.auth.content_type = obj.content.content_type();
|
||||
return str >> obj.auth;
|
||||
}
|
||||
|
||||
bool
|
||||
operator==(const AuthenticatedContent& lhs, const AuthenticatedContent& rhs)
|
||||
{
|
||||
return lhs.wire_format == rhs.wire_format && lhs.content == rhs.content &&
|
||||
lhs.auth == rhs.auth;
|
||||
}
|
||||
|
||||
AuthenticatedContent::AuthenticatedContent(WireFormat wire_format_in,
|
||||
GroupContent content_in)
|
||||
: wire_format(wire_format_in)
|
||||
, content(std::move(content_in))
|
||||
{
|
||||
auth.content_type = content.content_type();
|
||||
}
|
||||
|
||||
AuthenticatedContent::AuthenticatedContent(WireFormat wire_format_in,
|
||||
GroupContent content_in,
|
||||
GroupContentAuthData auth_in)
|
||||
: wire_format(wire_format_in)
|
||||
, content(std::move(content_in))
|
||||
, auth(std::move(auth_in))
|
||||
{
|
||||
}
|
||||
|
||||
const AuthenticatedContent&
|
||||
ValidatedContent::authenticated_content() const
|
||||
{
|
||||
return content_auth;
|
||||
}
|
||||
|
||||
ValidatedContent::ValidatedContent(AuthenticatedContent content_auth_in)
|
||||
: content_auth(std::move(content_auth_in))
|
||||
{
|
||||
}
|
||||
|
||||
bool
|
||||
operator==(const ValidatedContent& lhs, const ValidatedContent& rhs)
|
||||
{
|
||||
return lhs.content_auth == rhs.content_auth;
|
||||
}
|
||||
|
||||
struct GroupContentTBS
|
||||
{
|
||||
WireFormat wire_format = WireFormat::reserved;
|
||||
const GroupContent& content;
|
||||
const std::optional<GroupContext>& context;
|
||||
};
|
||||
|
||||
static tls::ostream&
|
||||
operator<<(tls::ostream& str, const GroupContentTBS& obj)
|
||||
{
|
||||
str << ProtocolVersion::mls10 << obj.wire_format << obj.content;
|
||||
|
||||
switch (obj.content.sender.sender_type()) {
|
||||
case SenderType::member:
|
||||
case SenderType::new_member_commit:
|
||||
str << opt::get(obj.context);
|
||||
break;
|
||||
|
||||
case SenderType::external:
|
||||
case SenderType::new_member_proposal:
|
||||
break;
|
||||
|
||||
default:
|
||||
throw InvalidParameterError("Invalid sender type");
|
||||
}
|
||||
|
||||
return str;
|
||||
}
|
||||
|
||||
bytes
|
||||
AuthenticatedContent::to_be_signed(
|
||||
const std::optional<GroupContext>& context) const
|
||||
{
|
||||
return tls::marshal(GroupContentTBS{
|
||||
wire_format,
|
||||
content,
|
||||
context,
|
||||
});
|
||||
}
|
||||
|
||||
PublicMessage
|
||||
PublicMessage::protect(AuthenticatedContent content_auth,
|
||||
CipherSuite suite,
|
||||
const std::optional<bytes>& membership_key,
|
||||
const std::optional<GroupContext>& context)
|
||||
{
|
||||
auto pt = PublicMessage(std::move(content_auth));
|
||||
|
||||
// Add the membership_mac if required
|
||||
switch (pt.content.sender.sender_type()) {
|
||||
case SenderType::member:
|
||||
pt.membership_tag =
|
||||
pt.membership_mac(suite, opt::get(membership_key), context);
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
return pt;
|
||||
}
|
||||
|
||||
std::optional<ValidatedContent>
|
||||
PublicMessage::unprotect(CipherSuite suite,
|
||||
const std::optional<bytes>& membership_key,
|
||||
const std::optional<GroupContext>& context) const
|
||||
{
|
||||
// Verify the membership_tag if the message was sent within the group
|
||||
switch (content.sender.sender_type()) {
|
||||
case SenderType::member: {
|
||||
auto candidate = membership_mac(suite, opt::get(membership_key), context);
|
||||
if (candidate != opt::get(membership_tag)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
return { { AuthenticatedContent{
|
||||
WireFormat::mls_public_message,
|
||||
content,
|
||||
auth,
|
||||
} } };
|
||||
}
|
||||
|
||||
bool
|
||||
PublicMessage::contains(const AuthenticatedContent& content_auth) const
|
||||
{
|
||||
return content == content_auth.content && auth == content_auth.auth;
|
||||
}
|
||||
|
||||
AuthenticatedContent
|
||||
PublicMessage::authenticated_content() const
|
||||
{
|
||||
auto auth_content = AuthenticatedContent{};
|
||||
auth_content.wire_format = WireFormat::mls_public_message;
|
||||
auth_content.content = content;
|
||||
auth_content.auth = auth;
|
||||
return auth_content;
|
||||
}
|
||||
|
||||
PublicMessage::PublicMessage(AuthenticatedContent content_auth)
|
||||
: content(std::move(content_auth.content))
|
||||
, auth(std::move(content_auth.auth))
|
||||
{
|
||||
if (content_auth.wire_format != WireFormat::mls_public_message) {
|
||||
throw InvalidParameterError("Wire format mismatch (not mls_plaintext)");
|
||||
}
|
||||
}
|
||||
|
||||
struct GroupContentTBM
|
||||
{
|
||||
GroupContentTBS content_tbs;
|
||||
GroupContentAuthData auth;
|
||||
|
||||
TLS_SERIALIZABLE(content_tbs, auth);
|
||||
};
|
||||
|
||||
bytes
|
||||
PublicMessage::membership_mac(CipherSuite suite,
|
||||
const bytes& membership_key,
|
||||
const std::optional<GroupContext>& context) const
|
||||
{
|
||||
auto tbm = tls::marshal(GroupContentTBM{
|
||||
{ WireFormat::mls_public_message, content, context },
|
||||
auth,
|
||||
});
|
||||
|
||||
return suite.digest().hmac(membership_key, tbm);
|
||||
}
|
||||
|
||||
tls::ostream&
|
||||
operator<<(tls::ostream& str, const PublicMessage& obj)
|
||||
{
|
||||
switch (obj.content.sender.sender_type()) {
|
||||
case SenderType::member:
|
||||
return str << obj.content << obj.auth << opt::get(obj.membership_tag);
|
||||
|
||||
case SenderType::external:
|
||||
case SenderType::new_member_proposal:
|
||||
case SenderType::new_member_commit:
|
||||
return str << obj.content << obj.auth;
|
||||
|
||||
default:
|
||||
throw InvalidParameterError("Invalid sender type");
|
||||
}
|
||||
}
|
||||
|
||||
tls::istream&
|
||||
operator>>(tls::istream& str, PublicMessage& obj)
|
||||
{
|
||||
str >> obj.content;
|
||||
|
||||
obj.auth.content_type = obj.content.content_type();
|
||||
str >> obj.auth;
|
||||
|
||||
if (obj.content.sender.sender_type() == SenderType::member) {
|
||||
obj.membership_tag.emplace();
|
||||
str >> opt::get(obj.membership_tag);
|
||||
}
|
||||
|
||||
return str;
|
||||
}
|
||||
|
||||
bool
|
||||
operator==(const PublicMessage& lhs, const PublicMessage& rhs)
|
||||
{
|
||||
return lhs.content == rhs.content && lhs.auth == rhs.auth &&
|
||||
lhs.membership_tag == rhs.membership_tag;
|
||||
}
|
||||
|
||||
bool
|
||||
operator!=(const PublicMessage& lhs, const PublicMessage& rhs)
|
||||
{
|
||||
return !(lhs == rhs);
|
||||
}
|
||||
|
||||
static bytes
|
||||
marshal_ciphertext_content(const GroupContent& content,
|
||||
const GroupContentAuthData& auth,
|
||||
size_t padding_size)
|
||||
{
|
||||
auto w = tls::ostream{};
|
||||
var::visit([&w](const auto& val) { w << val; }, content.content);
|
||||
w << auth;
|
||||
w.write_raw(bytes(padding_size, 0));
|
||||
return w.bytes();
|
||||
}
|
||||
|
||||
static void
|
||||
unmarshal_ciphertext_content(const bytes& content_pt,
|
||||
GroupContent& content,
|
||||
GroupContentAuthData& auth)
|
||||
{
|
||||
auto r = tls::istream(content_pt);
|
||||
|
||||
var::visit([&r](auto& val) { r >> val; }, content.content);
|
||||
r >> auth;
|
||||
|
||||
const auto padding = r.bytes();
|
||||
const auto nonzero = [](const auto& x) { return x != 0; };
|
||||
if (stdx::any_of(padding, nonzero)) {
|
||||
throw ProtocolError("Malformed AuthenticatedContentTBE padding");
|
||||
}
|
||||
}
|
||||
|
||||
struct ContentAAD
|
||||
{
|
||||
const bytes& group_id;
|
||||
const epoch_t epoch;
|
||||
const ContentType content_type;
|
||||
const bytes& authenticated_data;
|
||||
|
||||
TLS_SERIALIZABLE(group_id, epoch, content_type, authenticated_data)
|
||||
};
|
||||
|
||||
struct SenderData
|
||||
{
|
||||
LeafIndex sender{ 0 };
|
||||
uint32_t generation{ 0 };
|
||||
ReuseGuard reuse_guard{ 0, 0, 0, 0 };
|
||||
|
||||
TLS_SERIALIZABLE(sender, generation, reuse_guard)
|
||||
};
|
||||
|
||||
struct SenderDataAAD
|
||||
{
|
||||
const bytes& group_id;
|
||||
const epoch_t epoch;
|
||||
const ContentType content_type;
|
||||
|
||||
TLS_SERIALIZABLE(group_id, epoch, content_type)
|
||||
};
|
||||
|
||||
PrivateMessage
|
||||
PrivateMessage::protect(AuthenticatedContent content_auth,
|
||||
CipherSuite suite,
|
||||
GroupKeySource& keys,
|
||||
const bytes& sender_data_secret,
|
||||
size_t padding_size)
|
||||
{
|
||||
// Pull keys from the secret tree
|
||||
auto index =
|
||||
var::get<MemberSender>(content_auth.content.sender.sender).sender;
|
||||
auto content_type = content_auth.content.content_type();
|
||||
auto [generation, reuse_guard, content_keys] = keys.next(content_type, index);
|
||||
|
||||
// Encrypt the content
|
||||
auto content_pt = marshal_ciphertext_content(
|
||||
content_auth.content, content_auth.auth, padding_size);
|
||||
auto content_aad = tls::marshal(ContentAAD{
|
||||
content_auth.content.group_id,
|
||||
content_auth.content.epoch,
|
||||
content_auth.content.content_type(),
|
||||
content_auth.content.authenticated_data,
|
||||
});
|
||||
|
||||
auto content_ct = suite.hpke().aead.seal(
|
||||
content_keys.key, content_keys.nonce, content_aad, content_pt);
|
||||
|
||||
// Encrypt the sender data
|
||||
auto sender_index =
|
||||
var::get<MemberSender>(content_auth.content.sender.sender).sender;
|
||||
auto sender_data_pt = tls::marshal(SenderData{
|
||||
sender_index,
|
||||
generation,
|
||||
reuse_guard,
|
||||
});
|
||||
auto sender_data_aad = tls::marshal(SenderDataAAD{
|
||||
content_auth.content.group_id,
|
||||
content_auth.content.epoch,
|
||||
content_auth.content.content_type(),
|
||||
});
|
||||
|
||||
auto sender_data_keys =
|
||||
KeyScheduleEpoch::sender_data_keys(suite, sender_data_secret, content_ct);
|
||||
|
||||
auto sender_data_ct = suite.hpke().aead.seal(sender_data_keys.key,
|
||||
sender_data_keys.nonce,
|
||||
sender_data_aad,
|
||||
sender_data_pt);
|
||||
|
||||
return PrivateMessage{
|
||||
std::move(content_auth.content),
|
||||
std::move(sender_data_ct),
|
||||
std::move(content_ct),
|
||||
};
|
||||
}
|
||||
|
||||
std::optional<ValidatedContent>
|
||||
PrivateMessage::unprotect(CipherSuite suite,
|
||||
GroupKeySource& keys,
|
||||
const bytes& sender_data_secret) const
|
||||
{
|
||||
// Decrypt and parse the sender data
|
||||
auto sender_data_keys =
|
||||
KeyScheduleEpoch::sender_data_keys(suite, sender_data_secret, ciphertext);
|
||||
auto sender_data_aad = tls::marshal(SenderDataAAD{
|
||||
group_id,
|
||||
epoch,
|
||||
content_type,
|
||||
});
|
||||
|
||||
auto sender_data_pt = suite.hpke().aead.open(sender_data_keys.key,
|
||||
sender_data_keys.nonce,
|
||||
sender_data_aad,
|
||||
encrypted_sender_data);
|
||||
if (!sender_data_pt) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
auto sender_data = tls::get<SenderData>(opt::get(sender_data_pt));
|
||||
if (!keys.has_leaf(sender_data.sender)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// Decrypt the content
|
||||
auto content_keys = keys.get(content_type,
|
||||
sender_data.sender,
|
||||
sender_data.generation,
|
||||
sender_data.reuse_guard);
|
||||
keys.erase(content_type, sender_data.sender, sender_data.generation);
|
||||
|
||||
auto content_aad = tls::marshal(ContentAAD{
|
||||
group_id,
|
||||
epoch,
|
||||
content_type,
|
||||
authenticated_data,
|
||||
});
|
||||
|
||||
auto content_pt = suite.hpke().aead.open(
|
||||
content_keys.key, content_keys.nonce, content_aad, ciphertext);
|
||||
if (!content_pt) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// Parse the content
|
||||
auto content = GroupContent{ group_id,
|
||||
epoch,
|
||||
{ MemberSender{ sender_data.sender } },
|
||||
authenticated_data,
|
||||
content_type };
|
||||
auto auth = GroupContentAuthData{ content_type, {}, {} };
|
||||
|
||||
unmarshal_ciphertext_content(opt::get(content_pt), content, auth);
|
||||
|
||||
return { { AuthenticatedContent{
|
||||
WireFormat::mls_private_message,
|
||||
std::move(content),
|
||||
std::move(auth),
|
||||
} } };
|
||||
}
|
||||
|
||||
PrivateMessage::PrivateMessage(GroupContent content,
|
||||
bytes encrypted_sender_data_in,
|
||||
bytes ciphertext_in)
|
||||
: group_id(std::move(content.group_id))
|
||||
, epoch(content.epoch)
|
||||
, content_type(content.content_type())
|
||||
, authenticated_data(std::move(content.authenticated_data))
|
||||
, encrypted_sender_data(std::move(encrypted_sender_data_in))
|
||||
, ciphertext(std::move(ciphertext_in))
|
||||
{
|
||||
}
|
||||
|
||||
bytes
|
||||
MLSMessage::group_id() const
|
||||
{
|
||||
return var::visit(
|
||||
overloaded{
|
||||
[](const PublicMessage& pt) -> bytes { return pt.get_group_id(); },
|
||||
[](const PrivateMessage& ct) -> bytes { return ct.get_group_id(); },
|
||||
[](const GroupInfo& gi) -> bytes { return gi.group_context.group_id; },
|
||||
[](const auto& /* unused */) -> bytes {
|
||||
throw InvalidParameterError("MLSMessage has no group_id");
|
||||
},
|
||||
},
|
||||
message);
|
||||
}
|
||||
|
||||
epoch_t
|
||||
MLSMessage::epoch() const
|
||||
{
|
||||
return var::visit(
|
||||
overloaded{
|
||||
[](const PublicMessage& pt) -> epoch_t { return pt.get_epoch(); },
|
||||
[](const PrivateMessage& pt) -> epoch_t { return pt.get_epoch(); },
|
||||
[](const auto& /* unused */) -> epoch_t {
|
||||
throw InvalidParameterError("MLSMessage has no epoch");
|
||||
},
|
||||
},
|
||||
message);
|
||||
}
|
||||
|
||||
WireFormat
|
||||
MLSMessage::wire_format() const
|
||||
{
|
||||
return tls::variant<WireFormat>::type(message);
|
||||
}
|
||||
|
||||
MLSMessage::MLSMessage(PublicMessage public_message)
|
||||
: message(std::move(public_message))
|
||||
{
|
||||
}
|
||||
|
||||
MLSMessage::MLSMessage(PrivateMessage private_message)
|
||||
: message(std::move(private_message))
|
||||
{
|
||||
}
|
||||
|
||||
MLSMessage::MLSMessage(Welcome welcome)
|
||||
: message(std::move(welcome))
|
||||
{
|
||||
}
|
||||
|
||||
MLSMessage::MLSMessage(GroupInfo group_info)
|
||||
: message(std::move(group_info))
|
||||
{
|
||||
}
|
||||
|
||||
MLSMessage::MLSMessage(KeyPackage key_package)
|
||||
: message(std::move(key_package))
|
||||
{
|
||||
}
|
||||
|
||||
MLSMessage
|
||||
external_proposal(CipherSuite suite,
|
||||
const bytes& group_id,
|
||||
epoch_t epoch,
|
||||
const Proposal& proposal,
|
||||
uint32_t signer_index,
|
||||
const SignaturePrivateKey& sig_priv)
|
||||
{
|
||||
switch (proposal.proposal_type()) {
|
||||
// These proposal types are OK
|
||||
case ProposalType::add:
|
||||
case ProposalType::remove:
|
||||
case ProposalType::psk:
|
||||
case ProposalType::reinit:
|
||||
case ProposalType::group_context_extensions:
|
||||
break;
|
||||
|
||||
// These proposal types are forbidden
|
||||
case ProposalType::invalid:
|
||||
case ProposalType::update:
|
||||
case ProposalType::external_init:
|
||||
default:
|
||||
throw ProtocolError("External proposal has invalid type");
|
||||
}
|
||||
|
||||
auto content = GroupContent{ group_id,
|
||||
epoch,
|
||||
{ ExternalSenderIndex{ signer_index } },
|
||||
{ /* no authenticated data */ },
|
||||
{ proposal } };
|
||||
auto content_auth = AuthenticatedContent::sign(
|
||||
WireFormat::mls_public_message, std::move(content), suite, sig_priv, {});
|
||||
|
||||
return PublicMessage::protect(std::move(content_auth), suite, {}, {});
|
||||
}
|
||||
|
||||
} // namespace mlspp
|
||||
437
DPP/mlspp/src/session.cpp
Executable file
437
DPP/mlspp/src/session.cpp
Executable file
@@ -0,0 +1,437 @@
|
||||
#include <mls/messages.h>
|
||||
#include <mls/session.h>
|
||||
|
||||
#include <deque>
|
||||
|
||||
namespace mlspp {
|
||||
|
||||
///
|
||||
/// Inner struct declarations for PendingJoin and Session
|
||||
///
|
||||
|
||||
struct PendingJoin::Inner
|
||||
{
|
||||
const CipherSuite suite;
|
||||
const HPKEPrivateKey init_priv;
|
||||
const HPKEPrivateKey leaf_priv;
|
||||
const SignaturePrivateKey sig_priv;
|
||||
const KeyPackage key_package;
|
||||
|
||||
Inner(CipherSuite suite_in,
|
||||
SignaturePrivateKey sig_priv_in,
|
||||
Credential cred_in);
|
||||
|
||||
static PendingJoin create(CipherSuite suite,
|
||||
SignaturePrivateKey sig_priv,
|
||||
Credential cred);
|
||||
};
|
||||
|
||||
struct Session::Inner
|
||||
{
|
||||
std::deque<State> history;
|
||||
std::map<bytes, State> outbound_cache;
|
||||
bool encrypt_handshake{ false };
|
||||
|
||||
explicit Inner(State state);
|
||||
|
||||
static Session begin(CipherSuite suite,
|
||||
const bytes& group_id,
|
||||
const HPKEPrivateKey& leaf_priv,
|
||||
const SignaturePrivateKey& sig_priv,
|
||||
const LeafNode& leaf_node);
|
||||
static Session join(const HPKEPrivateKey& init_priv,
|
||||
const HPKEPrivateKey& leaf_priv,
|
||||
const SignaturePrivateKey& sig_priv,
|
||||
const KeyPackage& key_package,
|
||||
const bytes& welcome_data);
|
||||
|
||||
bytes fresh_secret() const;
|
||||
MLSMessage import_handshake(const bytes& encoded) const;
|
||||
State& for_epoch(epoch_t epoch);
|
||||
};
|
||||
|
||||
///
|
||||
/// Client
|
||||
///
|
||||
|
||||
Client::Client(CipherSuite suite_in,
|
||||
SignaturePrivateKey sig_priv_in,
|
||||
Credential cred_in)
|
||||
: suite(suite_in)
|
||||
, sig_priv(std::move(sig_priv_in))
|
||||
, cred(std::move(cred_in))
|
||||
{
|
||||
}
|
||||
|
||||
Session
|
||||
Client::begin_session(const bytes& group_id) const
|
||||
{
|
||||
auto leaf_priv = HPKEPrivateKey::generate(suite);
|
||||
auto leaf_node = LeafNode(suite,
|
||||
leaf_priv.public_key,
|
||||
sig_priv.public_key,
|
||||
cred,
|
||||
Capabilities::create_default(),
|
||||
Lifetime::create_default(),
|
||||
{},
|
||||
sig_priv);
|
||||
|
||||
return Session::Inner::begin(suite, group_id, leaf_priv, sig_priv, leaf_node);
|
||||
}
|
||||
|
||||
PendingJoin
|
||||
Client::start_join() const
|
||||
{
|
||||
return PendingJoin::Inner::create(suite, sig_priv, cred);
|
||||
}
|
||||
|
||||
///
|
||||
/// PendingJoin
|
||||
///
|
||||
|
||||
PendingJoin::Inner::Inner(CipherSuite suite_in,
|
||||
SignaturePrivateKey sig_priv_in,
|
||||
Credential cred_in)
|
||||
: suite(suite_in)
|
||||
, init_priv(HPKEPrivateKey::generate(suite))
|
||||
, leaf_priv(HPKEPrivateKey::generate(suite))
|
||||
, sig_priv(std::move(sig_priv_in))
|
||||
, key_package(suite,
|
||||
init_priv.public_key,
|
||||
LeafNode(suite,
|
||||
leaf_priv.public_key,
|
||||
sig_priv.public_key,
|
||||
std::move(cred_in),
|
||||
Capabilities::create_default(),
|
||||
Lifetime::create_default(),
|
||||
{},
|
||||
sig_priv),
|
||||
{},
|
||||
sig_priv)
|
||||
{
|
||||
}
|
||||
|
||||
PendingJoin
|
||||
PendingJoin::Inner::create(CipherSuite suite,
|
||||
SignaturePrivateKey sig_priv,
|
||||
Credential cred)
|
||||
{
|
||||
auto inner =
|
||||
std::make_unique<Inner>(suite, std::move(sig_priv), std::move(cred));
|
||||
return { inner.release() };
|
||||
}
|
||||
|
||||
PendingJoin::PendingJoin(PendingJoin&& other) noexcept = default;
|
||||
|
||||
PendingJoin&
|
||||
PendingJoin::operator=(PendingJoin&& other) noexcept = default;
|
||||
|
||||
PendingJoin::~PendingJoin() = default;
|
||||
|
||||
PendingJoin::PendingJoin(Inner* inner_in)
|
||||
: inner(inner_in)
|
||||
{
|
||||
}
|
||||
|
||||
bytes
|
||||
PendingJoin::key_package() const
|
||||
{
|
||||
return tls::marshal(inner->key_package);
|
||||
}
|
||||
|
||||
Session
|
||||
PendingJoin::complete(const bytes& welcome) const
|
||||
{
|
||||
return Session::Inner::join(inner->init_priv,
|
||||
inner->leaf_priv,
|
||||
inner->sig_priv,
|
||||
inner->key_package,
|
||||
welcome);
|
||||
}
|
||||
|
||||
///
|
||||
/// Session
|
||||
///
|
||||
|
||||
Session::Inner::Inner(State state)
|
||||
: history{ std::move(state) }
|
||||
, encrypt_handshake(true)
|
||||
{
|
||||
}
|
||||
|
||||
Session
|
||||
Session::Inner::begin(CipherSuite suite,
|
||||
const bytes& group_id,
|
||||
const HPKEPrivateKey& leaf_priv,
|
||||
const SignaturePrivateKey& sig_priv,
|
||||
const LeafNode& leaf_node)
|
||||
{
|
||||
auto state = State(group_id, suite, leaf_priv, sig_priv, leaf_node, {});
|
||||
auto inner = std::make_unique<Inner>(state);
|
||||
return { inner.release() };
|
||||
}
|
||||
|
||||
Session
|
||||
Session::Inner::join(const HPKEPrivateKey& init_priv,
|
||||
const HPKEPrivateKey& leaf_priv,
|
||||
const SignaturePrivateKey& sig_priv,
|
||||
const KeyPackage& key_package,
|
||||
const bytes& welcome_data)
|
||||
{
|
||||
auto welcome = tls::get<Welcome>(welcome_data);
|
||||
|
||||
auto state = State(
|
||||
init_priv, leaf_priv, sig_priv, key_package, welcome, std::nullopt, {});
|
||||
auto inner = std::make_unique<Inner>(state);
|
||||
return { inner.release() };
|
||||
}
|
||||
|
||||
bytes
|
||||
Session::Inner::fresh_secret() const
|
||||
{
|
||||
const auto suite = history.front().cipher_suite();
|
||||
return random_bytes(suite.secret_size());
|
||||
}
|
||||
|
||||
MLSMessage
|
||||
Session::Inner::import_handshake(const bytes& encoded) const
|
||||
{
|
||||
auto msg = tls::get<MLSMessage>(encoded);
|
||||
|
||||
switch (msg.wire_format()) {
|
||||
case WireFormat::mls_public_message:
|
||||
if (encrypt_handshake) {
|
||||
throw ProtocolError("Handshake not encrypted as required");
|
||||
}
|
||||
|
||||
return msg;
|
||||
|
||||
case WireFormat::mls_private_message: {
|
||||
if (!encrypt_handshake) {
|
||||
throw ProtocolError("Unexpected handshake encryption");
|
||||
}
|
||||
|
||||
return msg;
|
||||
}
|
||||
|
||||
default:
|
||||
throw InvalidParameterError("Illegal wire format");
|
||||
}
|
||||
}
|
||||
|
||||
State&
|
||||
Session::Inner::for_epoch(epoch_t epoch)
|
||||
{
|
||||
for (auto& state : history) {
|
||||
if (state.epoch() == epoch) {
|
||||
return state;
|
||||
}
|
||||
}
|
||||
|
||||
throw MissingStateError("No state for epoch");
|
||||
}
|
||||
|
||||
Session::Session(Session&& other) noexcept = default;
|
||||
|
||||
Session&
|
||||
Session::operator=(Session&& other) noexcept = default;
|
||||
|
||||
Session::~Session() = default;
|
||||
|
||||
Session::Session(Inner* inner_in)
|
||||
: inner(inner_in)
|
||||
{
|
||||
}
|
||||
|
||||
void
|
||||
Session::encrypt_handshake(bool enabled)
|
||||
{
|
||||
inner->encrypt_handshake = enabled;
|
||||
}
|
||||
|
||||
bytes
|
||||
Session::add(const bytes& key_package_data)
|
||||
{
|
||||
auto key_package = tls::get<KeyPackage>(key_package_data);
|
||||
auto proposal = inner->history.front().add(
|
||||
key_package, { inner->encrypt_handshake, {}, 0 });
|
||||
return tls::marshal(proposal);
|
||||
}
|
||||
|
||||
bytes
|
||||
Session::update()
|
||||
{
|
||||
auto leaf_secret = inner->fresh_secret();
|
||||
|
||||
auto leaf_priv = HPKEPrivateKey::generate(cipher_suite());
|
||||
auto proposal = inner->history.front().update(
|
||||
std::move(leaf_priv), {}, { inner->encrypt_handshake, {}, 0 });
|
||||
return tls::marshal(proposal);
|
||||
}
|
||||
|
||||
bytes
|
||||
Session::remove(uint32_t index)
|
||||
{
|
||||
auto proposal = inner->history.front().remove(
|
||||
RosterIndex{ index }, { inner->encrypt_handshake, {}, 0 });
|
||||
return tls::marshal(proposal);
|
||||
}
|
||||
|
||||
std::tuple<bytes, bytes>
|
||||
Session::commit(const bytes& proposal)
|
||||
{
|
||||
return commit(std::vector<bytes>{ proposal });
|
||||
}
|
||||
|
||||
std::tuple<bytes, bytes>
|
||||
Session::commit(const std::vector<bytes>& proposals)
|
||||
{
|
||||
auto provisional_state = inner->history.front();
|
||||
for (const auto& proposal_data : proposals) {
|
||||
auto msg = inner->import_handshake(proposal_data);
|
||||
auto maybe_state = provisional_state.handle(msg);
|
||||
if (maybe_state) {
|
||||
throw InvalidParameterError("Invalid proposal; actually a commit");
|
||||
}
|
||||
}
|
||||
|
||||
inner->history.front() = std::move(provisional_state);
|
||||
return commit();
|
||||
}
|
||||
|
||||
std::tuple<bytes, bytes>
|
||||
Session::commit()
|
||||
{
|
||||
auto commit_secret = inner->fresh_secret();
|
||||
auto encrypt = inner->encrypt_handshake;
|
||||
auto [commit, welcome, new_state] = inner->history.front().commit(
|
||||
commit_secret, CommitOpts{ {}, true, encrypt, {} }, { encrypt, {}, 0 });
|
||||
|
||||
auto commit_msg = tls::marshal(commit);
|
||||
auto welcome_msg = tls::marshal(welcome);
|
||||
|
||||
inner->outbound_cache.insert({ commit_msg, new_state });
|
||||
return std::make_tuple(welcome_msg, commit_msg);
|
||||
}
|
||||
|
||||
bool
|
||||
Session::handle(const bytes& handshake_data)
|
||||
{
|
||||
auto msg = inner->import_handshake(handshake_data);
|
||||
|
||||
auto maybe_cached_state = std::optional<State>{};
|
||||
auto node = inner->outbound_cache.extract(handshake_data);
|
||||
if (!node.empty()) {
|
||||
maybe_cached_state = node.mapped();
|
||||
}
|
||||
|
||||
auto maybe_next_state =
|
||||
inner->history.front().handle(msg, maybe_cached_state);
|
||||
if (!maybe_next_state) {
|
||||
return false;
|
||||
}
|
||||
|
||||
inner->history.emplace_front(opt::get(maybe_next_state));
|
||||
return true;
|
||||
}
|
||||
|
||||
epoch_t
|
||||
Session::epoch() const
|
||||
{
|
||||
return inner->history.front().epoch();
|
||||
}
|
||||
|
||||
LeafIndex
|
||||
Session::index() const
|
||||
{
|
||||
return inner->history.front().index();
|
||||
}
|
||||
|
||||
CipherSuite
|
||||
Session::cipher_suite() const
|
||||
{
|
||||
return inner->history.front().cipher_suite();
|
||||
}
|
||||
|
||||
const ExtensionList&
|
||||
Session::extensions() const
|
||||
{
|
||||
return inner->history.front().extensions();
|
||||
}
|
||||
|
||||
const TreeKEMPublicKey&
|
||||
Session::tree() const
|
||||
{
|
||||
return inner->history.front().tree();
|
||||
}
|
||||
|
||||
bytes
|
||||
Session::do_export(const std::string& label,
|
||||
const bytes& context,
|
||||
size_t size) const
|
||||
{
|
||||
return inner->history.front().do_export(label, context, size);
|
||||
}
|
||||
|
||||
GroupInfo
|
||||
Session::group_info() const
|
||||
{
|
||||
return inner->history.front().group_info(true);
|
||||
}
|
||||
|
||||
std::vector<LeafNode>
|
||||
Session::roster() const
|
||||
{
|
||||
return inner->history.front().roster();
|
||||
}
|
||||
|
||||
bytes
|
||||
Session::epoch_authenticator() const
|
||||
{
|
||||
return inner->history.front().epoch_authenticator();
|
||||
}
|
||||
|
||||
bytes
|
||||
Session::protect(const bytes& plaintext)
|
||||
{
|
||||
auto msg = inner->history.front().protect({}, plaintext, 0);
|
||||
return tls::marshal(msg);
|
||||
}
|
||||
|
||||
// TODO(rlb@ipv.sx): It would be good to expose identity information
|
||||
// here, since ciphertexts are authenticated per sender. Who sent
|
||||
// this ciphertext?
|
||||
bytes
|
||||
Session::unprotect(const bytes& ciphertext)
|
||||
{
|
||||
auto ciphertext_obj = tls::get<MLSMessage>(ciphertext);
|
||||
auto& state = inner->for_epoch(ciphertext_obj.epoch());
|
||||
auto [aad, pt] = state.unprotect(ciphertext_obj);
|
||||
silence_unused(aad);
|
||||
return pt;
|
||||
}
|
||||
|
||||
bool
|
||||
operator==(const Session& lhs, const Session& rhs)
|
||||
{
|
||||
if (lhs.inner->encrypt_handshake != rhs.inner->encrypt_handshake) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto size = std::min(lhs.inner->history.size(), rhs.inner->history.size());
|
||||
for (size_t i = 0; i < size; i += 1) {
|
||||
if (lhs.inner->history.at(i) != rhs.inner->history.at(i)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool
|
||||
operator!=(const Session& lhs, const Session& rhs)
|
||||
{
|
||||
return !(lhs == rhs);
|
||||
}
|
||||
|
||||
} // namespace mlspp
|
||||
2219
DPP/mlspp/src/state.cpp
Executable file
2219
DPP/mlspp/src/state.cpp
Executable file
File diff suppressed because it is too large
Load Diff
223
DPP/mlspp/src/tree_math.cpp
Executable file
223
DPP/mlspp/src/tree_math.cpp
Executable file
@@ -0,0 +1,223 @@
|
||||
#include "mls/tree_math.h"
|
||||
#include "mls/common.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
static const uint32_t one = 0x01;
|
||||
|
||||
static uint32_t
|
||||
log2(uint32_t x)
|
||||
{
|
||||
if (x == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint32_t k = 0;
|
||||
while ((x >> k) > 0) {
|
||||
k += 1;
|
||||
}
|
||||
return k - 1;
|
||||
}
|
||||
|
||||
namespace mlspp {
|
||||
|
||||
LeafCount::LeafCount(const NodeCount w)
|
||||
{
|
||||
if (w.val == 0) {
|
||||
val = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
if ((w.val & one) == 0) {
|
||||
throw InvalidParameterError("Only odd node counts describe trees");
|
||||
}
|
||||
|
||||
val = (w.val >> one) + 1;
|
||||
}
|
||||
|
||||
LeafCount
|
||||
LeafCount::full(const LeafCount n)
|
||||
{
|
||||
auto w = uint32_t(1);
|
||||
while (w < n.val) {
|
||||
w <<= 1U;
|
||||
}
|
||||
return LeafCount{ w };
|
||||
}
|
||||
|
||||
NodeCount::NodeCount(const LeafCount n)
|
||||
: UInt32(2 * (n.val - 1) + 1)
|
||||
{
|
||||
}
|
||||
|
||||
LeafIndex::LeafIndex(NodeIndex x)
|
||||
: UInt32(0)
|
||||
{
|
||||
if (x.val % 2 == 1) {
|
||||
throw InvalidParameterError("Only even node indices describe leaves");
|
||||
}
|
||||
|
||||
val = x.val >> 1; // NOLINT(hicpp-signed-bitwise)
|
||||
}
|
||||
|
||||
NodeIndex
|
||||
LeafIndex::ancestor(LeafIndex other) const
|
||||
{
|
||||
auto ln = NodeIndex(*this);
|
||||
auto rn = NodeIndex(other);
|
||||
if (ln == rn) {
|
||||
return ln;
|
||||
}
|
||||
|
||||
uint8_t k = 0;
|
||||
while (ln != rn) {
|
||||
ln.val = ln.val >> 1U;
|
||||
rn.val = rn.val >> 1U;
|
||||
k += 1;
|
||||
}
|
||||
|
||||
const uint32_t prefix = ln.val << k;
|
||||
const uint32_t stop = (1U << uint8_t(k - 1));
|
||||
return NodeIndex{ prefix + (stop - 1) };
|
||||
}
|
||||
|
||||
NodeIndex::NodeIndex(LeafIndex x)
|
||||
: UInt32(2 * x.val)
|
||||
{
|
||||
}
|
||||
|
||||
NodeIndex
|
||||
NodeIndex::root(LeafCount n)
|
||||
{
|
||||
if (n.val == 0) {
|
||||
throw std::runtime_error("Root for zero-size tree is undefined");
|
||||
}
|
||||
|
||||
auto w = NodeCount(n);
|
||||
return NodeIndex{ (one << log2(w.val)) - 1 };
|
||||
}
|
||||
|
||||
bool
|
||||
NodeIndex::is_leaf() const
|
||||
{
|
||||
return val % 2 == 0;
|
||||
}
|
||||
|
||||
bool
|
||||
NodeIndex::is_below(NodeIndex other) const
|
||||
{
|
||||
auto lx = level();
|
||||
auto ly = other.level();
|
||||
return lx <= ly && (val >> (ly + 1) == other.val >> (ly + 1));
|
||||
}
|
||||
|
||||
NodeIndex
|
||||
NodeIndex::left() const
|
||||
{
|
||||
if (is_leaf()) {
|
||||
return *this;
|
||||
}
|
||||
|
||||
// The clang analyzer doesn't realize that is_leaf() assures that level >= 1
|
||||
// NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult)
|
||||
return NodeIndex{ val ^ (one << (level() - 1)) };
|
||||
}
|
||||
|
||||
NodeIndex
|
||||
NodeIndex::right() const
|
||||
{
|
||||
if (is_leaf()) {
|
||||
return *this;
|
||||
}
|
||||
|
||||
return NodeIndex{ val ^ (uint32_t(0x03) << (level() - 1)) };
|
||||
}
|
||||
|
||||
NodeIndex
|
||||
NodeIndex::parent() const
|
||||
{
|
||||
auto k = level();
|
||||
return NodeIndex{ (val | (one << k)) & ~(one << (k + 1)) };
|
||||
}
|
||||
|
||||
NodeIndex
|
||||
NodeIndex::sibling() const
|
||||
{
|
||||
return sibling(parent());
|
||||
}
|
||||
|
||||
NodeIndex
|
||||
NodeIndex::sibling(NodeIndex ancestor) const
|
||||
{
|
||||
if (!is_below(ancestor)) {
|
||||
throw InvalidParameterError("Node is not below claimed ancestor");
|
||||
}
|
||||
|
||||
auto l = ancestor.left();
|
||||
auto r = ancestor.right();
|
||||
|
||||
if (is_below(l)) {
|
||||
return r;
|
||||
}
|
||||
|
||||
return l;
|
||||
}
|
||||
|
||||
std::vector<NodeIndex>
|
||||
NodeIndex::dirpath(LeafCount n)
|
||||
{
|
||||
if (val >= NodeCount(n).val) {
|
||||
throw InvalidParameterError("Request for dirpath outside of tree");
|
||||
}
|
||||
|
||||
auto d = std::vector<NodeIndex>{};
|
||||
|
||||
auto r = root(n);
|
||||
if (*this == r) {
|
||||
return d;
|
||||
}
|
||||
|
||||
auto p = parent();
|
||||
while (p.val != r.val) {
|
||||
d.push_back(p);
|
||||
p = p.parent();
|
||||
}
|
||||
|
||||
// Include the root except in a one-member tree
|
||||
if (val != r.val) {
|
||||
d.push_back(p);
|
||||
}
|
||||
|
||||
return d;
|
||||
}
|
||||
|
||||
std::vector<NodeIndex>
|
||||
NodeIndex::copath(LeafCount n)
|
||||
{
|
||||
auto d = dirpath(n);
|
||||
if (d.empty()) {
|
||||
return {};
|
||||
}
|
||||
|
||||
// Prepend leaf; omit root
|
||||
d.insert(d.begin(), *this);
|
||||
d.pop_back();
|
||||
|
||||
return stdx::transform<NodeIndex>(d, [](auto x) { return x.sibling(); });
|
||||
}
|
||||
|
||||
uint32_t
|
||||
NodeIndex::level() const
|
||||
{
|
||||
if ((val & one) == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint32_t k = 0;
|
||||
while (((val >> k) & one) == 1) {
|
||||
k += 1;
|
||||
}
|
||||
return k;
|
||||
}
|
||||
|
||||
} // namespace mlspp
|
||||
1127
DPP/mlspp/src/treekem.cpp
Executable file
1127
DPP/mlspp/src/treekem.cpp
Executable file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user