#pragma once #include #include #include #include #include #include #include #include #include namespace mlspp::tls { // For indicating no min or max in vector definitions const size_t none = std::numeric_limits::max(); class WriteError : public std::invalid_argument { public: using parent = std::invalid_argument; using parent::parent; }; class ReadError : public std::invalid_argument { public: using parent = std::invalid_argument; using parent::parent; }; /// /// Declarations of Streams and Traits /// class ostream { public: static const size_t none = std::numeric_limits::max(); void write_raw(const std::vector& bytes); const std::vector& bytes() const { return _buffer; } size_t size() const { return _buffer.size(); } bool empty() const { return _buffer.empty(); } private: std::vector _buffer; ostream& write_uint(uint64_t value, int length); friend ostream& operator<<(ostream& out, bool data); friend ostream& operator<<(ostream& out, uint8_t data); friend ostream& operator<<(ostream& out, uint16_t data); friend ostream& operator<<(ostream& out, uint32_t data); friend ostream& operator<<(ostream& out, uint64_t data); template friend ostream& operator<<(ostream& out, const std::vector& data); friend struct varint; }; class istream { public: istream(const std::vector& data) : _buffer(data) { // So that we can use the constant-time pop_back std::reverse(_buffer.begin(), _buffer.end()); } size_t size() const { return _buffer.size(); } bool empty() const { return _buffer.empty(); } std::vector bytes() { auto bytes = _buffer; std::reverse(bytes.begin(), bytes.end()); return bytes; } private: istream() {} std::vector _buffer; uint8_t next(); template istream& read_uint(T& data, size_t length) { uint64_t value = 0; for (size_t i = 0; i < length; i += 1) { value = (value << unsigned(8)) + next(); } data = static_cast(value); return *this; } friend istream& operator>>(istream& in, bool& data); friend istream& operator>>(istream& in, uint8_t& data); friend istream& operator>>(istream& in, uint16_t& data); friend istream& operator>>(istream& in, uint32_t& data); friend istream& operator>>(istream& in, uint64_t& data); template friend istream& operator>>(istream& in, std::vector& data); friend struct varint; }; // Traits must have static encode and decode methods, of the following form: // // static ostream& encode(ostream& str, const T& val); // static istream& decode(istream& str, T& val); // // Trait types will never be constructed; only these static methods are used. // The value arguments to encode and decode can be as strict or as loose as // desired. // // Ultimately, all interesting encoding should be done through traits. // // * vectors // * variants // * varints struct pass { template static ostream& encode(ostream& str, const T& val); template static istream& decode(istream& str, T& val); }; template struct variant { template static inline Ts type(const var::variant& data); template static ostream& encode(ostream& str, const var::variant& data); template static inline typename std::enable_if::type read_variant(istream&, Te, var::variant&); template static inline typename std::enable_if < I::type read_variant(istream& str, Te target_type, var::variant& v); template static istream& decode(istream& str, var::variant& data); }; struct varint { static ostream& encode(ostream& str, const uint64_t& val); static istream& decode(istream& str, uint64_t& val); }; /// /// Writer implementations /// // Primitive writers defined in .cpp file // Array writer template ostream& operator<<(ostream& out, const std::array& data) { for (const auto& item : data) { out << item; } return out; } // Optional writer template ostream& operator<<(ostream& out, const std::optional& opt) { if (!opt) { return out << uint8_t(0); } return out << uint8_t(1) << opt::get(opt); } // Enum writer template::value, int> = 0> ostream& operator<<(ostream& str, const T& val) { auto u = static_cast>(val); return str << u; } // Vector writer template ostream& operator<<(ostream& str, const std::vector& vec) { // Pre-encode contents ostream temp; for (const auto& item : vec) { temp << item; } // Write the encoded length, then the pre-encoded data varint::encode(str, temp._buffer.size()); str.write_raw(temp.bytes()); return str; } /// /// Reader implementations /// // Primitive type readers defined in .cpp file // Array reader template istream& operator>>(istream& in, std::array& data) { for (auto& item : data) { in >> item; } return in; } // Optional reader template istream& operator>>(istream& in, std::optional& opt) { uint8_t present = 0; in >> present; switch (present) { case 0: opt.reset(); return in; case 1: opt.emplace(); return in >> opt::get(opt); default: throw std::invalid_argument("Malformed optional"); } } // Enum reader // XXX(rlb): It would be nice if this could enforce that the values are valid, // but C++ doesn't seem to have that ability. When used as a tag for variants, // the variant reader will enforce, at least. template::value, int> = 0> istream& operator>>(istream& str, T& val) { std::underlying_type_t u; str >> u; val = static_cast(u); return str; } // Vector reader template istream& operator>>(istream& str, std::vector& vec) { // Read the encoded data size auto size = uint64_t(0); varint::decode(str, size); if (size > str._buffer.size()) { throw ReadError("Vector is longer than remaining data"); } // Read the elements of the vector // NB: Remember that we store the vector in reverse order // NB: This requires that T be default-constructible istream r; r._buffer = std::vector{ str._buffer.end() - size, str._buffer.end() }; vec.clear(); while (r._buffer.size() > 0) { vec.emplace_back(); r >> vec.back(); } // Truncate the primary buffer str._buffer.erase(str._buffer.end() - size, str._buffer.end()); return str; } // Abbreviations template std::vector marshal(const T& value) { ostream w; w << value; return w.bytes(); } template void unmarshal(const std::vector& data, T& value) { istream r(data); r >> value; } template T get(const std::vector& data, Tp... args) { T value(args...); unmarshal(data, value); return value; } // Use this macro to define struct serialization with minimal boilerplate #define TLS_SERIALIZABLE(...) \ static const bool _tls_serializable = true; \ auto _tls_fields_r() \ { \ return std::forward_as_tuple(__VA_ARGS__); \ } \ auto _tls_fields_w() const \ { \ return std::forward_as_tuple(__VA_ARGS__); \ } // If your struct contains nontrivial members (e.g., vectors), use this to // define traits for them. #define TLS_TRAITS(...) \ static const bool _tls_has_traits = true; \ using _tls_traits = std::tuple<__VA_ARGS__>; template struct is_serializable { template static std::true_type test(decltype(U::_tls_serializable)); template static std::false_type test(...); static const bool value = decltype(test(true))::value; }; template struct has_traits { template static std::true_type test(decltype(U::_tls_has_traits)); template static std::false_type test(...); static const bool value = decltype(test(true))::value; }; /// /// Trait implementations /// // Pass-through (normal encoding/decoding) template ostream& pass::encode(ostream& str, const T& val) { return str << val; } template istream& pass::decode(istream& str, T& val) { return str >> val; } // Variant encoding template constexpr Ts variant_map(); #define TLS_VARIANT_MAP(EnumType, MappedType, enum_value) \ template<> \ constexpr EnumType variant_map() \ { \ return EnumType::enum_value; \ } template template inline Ts variant::type(const var::variant& data) { const auto get_type = [](const auto& v) { return variant_map>(); }; return var::visit(get_type, data); } template template ostream& variant::encode(ostream& str, const var::variant& data) { const auto write_variant = [&str](auto&& v) { using Tv = std::decay_t; str << variant_map() << v; }; var::visit(write_variant, data); return str; } template template inline typename std::enable_if::type variant::read_variant(istream&, Te, var::variant&) { throw ReadError("Invalid variant type label"); } template template inline typename std::enable_if < I::type variant::read_variant(istream& str, Te target_type, var::variant& v) { using Tc = var::variant_alternative_t>; if (variant_map() == target_type) { str >> v.template emplace(); return; } read_variant(str, target_type, v); } template template istream& variant::decode(istream& str, var::variant& data) { Ts target_type; str >> target_type; read_variant(str, target_type, data); return str; } // Struct writer without traits (enabled by macro) template inline typename std::enable_if::type write_tuple(ostream&, const std::tuple&) { } template inline typename std::enable_if < I::type write_tuple(ostream& str, const std::tuple& t) { str << std::get(t); write_tuple(str, t); } template inline typename std::enable_if::value && !has_traits::value, ostream&>::type operator<<(ostream& str, const T& obj) { write_tuple(str, obj._tls_fields_w()); return str; } // Struct writer with traits (enabled by macro) template inline typename std::enable_if::type write_tuple_traits(ostream&, const std::tuple&) { } template inline typename std::enable_if < I::type write_tuple_traits(ostream& str, const std::tuple& t) { std::tuple_element_t::encode(str, std::get(t)); write_tuple_traits(str, t); } template inline typename std::enable_if::value && has_traits::value, ostream&>::type operator<<(ostream& str, const T& obj) { write_tuple_traits(str, obj._tls_fields_w()); return str; } // Struct reader without traits (enabled by macro) template inline typename std::enable_if::type read_tuple(istream&, const std::tuple&) { } template inline typename std::enable_if < I::type read_tuple(istream& str, const std::tuple& t) { str >> std::get(t); read_tuple(str, t); } template inline typename std::enable_if::value && !has_traits::value, istream&>::type operator>>(istream& str, T& obj) { read_tuple(str, obj._tls_fields_r()); return str; } // Struct reader with traits (enabled by macro) template inline typename std::enable_if::type read_tuple_traits(istream&, const std::tuple&) { } template inline typename std::enable_if < I::type read_tuple_traits(istream& str, const std::tuple& t) { std::tuple_element_t::decode(str, std::get(t)); read_tuple_traits(str, t); } template inline typename std::enable_if::value && has_traits::value, istream&>::type operator>>(istream& str, T& obj) { read_tuple_traits(str, obj._tls_fields_r()); return str; } } // namespace mlspp::tls