This commit is contained in:
2024-10-10 19:05:48 +00:00
commit cffdcba6af
1880 changed files with 813614 additions and 0 deletions

86
td/tdnet/CMakeLists.txt Normal file
View File

@@ -0,0 +1,86 @@
if ((CMAKE_MAJOR_VERSION LESS 3) OR (CMAKE_VERSION VERSION_LESS "3.0.2"))
message(FATAL_ERROR "CMake >= 3.0.2 is required")
endif()
if (NOT DEFINED CMAKE_INSTALL_LIBDIR)
set(CMAKE_INSTALL_LIBDIR "lib")
endif()
if (NOT OPENSSL_FOUND)
find_package(OpenSSL REQUIRED)
find_package(ZLIB REQUIRED)
endif()
set(TDNET_SOURCE
td/net/GetHostByNameActor.cpp
td/net/HttpChunkedByteFlow.cpp
td/net/HttpConnectionBase.cpp
td/net/HttpContentLengthByteFlow.cpp
td/net/HttpFile.cpp
td/net/HttpInboundConnection.cpp
td/net/HttpOutboundConnection.cpp
td/net/HttpProxy.cpp
td/net/HttpQuery.cpp
td/net/HttpReader.cpp
td/net/Socks5.cpp
td/net/SslCtx.cpp
td/net/SslStream.cpp
td/net/TcpListener.cpp
td/net/TransparentProxy.cpp
td/net/Wget.cpp
td/net/GetHostByNameActor.h
td/net/HttpChunkedByteFlow.h
td/net/HttpConnectionBase.h
td/net/HttpContentLengthByteFlow.h
td/net/HttpFile.h
td/net/HttpHeaderCreator.h
td/net/HttpInboundConnection.h
td/net/HttpOutboundConnection.h
td/net/HttpProxy.h
td/net/HttpQuery.h
td/net/HttpReader.h
td/net/NetStats.h
td/net/Socks5.h
td/net/SslCtx.h
td/net/SslStream.h
td/net/TcpListener.h
td/net/TransparentProxy.h
td/net/Wget.h
)
if (APPLE_WATCH)
set(TDNET_SOURCE
${TDNET_SOURCE}
td/net/DarwinHttp.mm
td/net/DarwinHttp.h
)
set_source_files_properties(td/net/DarwinHttp.mm PROPERTIES COMPILE_FLAGS -fobjc-arc)
endif()
add_library(tdnet STATIC ${TDNET_SOURCE})
target_include_directories(tdnet PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>)
target_include_directories(tdnet SYSTEM PRIVATE $<BUILD_INTERFACE:${OPENSSL_INCLUDE_DIR}>)
target_link_libraries(tdnet PUBLIC tdutils tdactor)
if (NOT EMSCRIPTEN)
target_link_libraries(tdnet PRIVATE ${OPENSSL_SSL_LIBRARY})
endif()
target_link_libraries(tdnet PRIVATE ${OPENSSL_CRYPTO_LIBRARY} ${CMAKE_DL_LIBS} ${ZLIB_LIBRARIES})
if (WIN32)
if (MINGW)
target_link_libraries(tdnet PRIVATE ws2_32 mswsock crypt32)
else()
target_link_libraries(tdnet PRIVATE ws2_32 Mswsock Crypt32)
endif()
endif()
if (APPLE_WATCH)
find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
target_link_libraries(tdnet PRIVATE ${FOUNDATION_LIBRARY})
endif()
install(TARGETS tdnet EXPORT TdStaticTargets
LIBRARY DESTINATION "${CMAKE_INSTALL_LIBDIR}"
ARCHIVE DESTINATION "${CMAKE_INSTALL_LIBDIR}"
)

View File

@@ -0,0 +1,21 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once
#include "td/utils/buffer.h"
#include "td/utils/Promise.h"
#include "td/utils/Slice.h"
namespace td {
class DarwinHttp {
public:
static void get(CSlice url, Promise<BufferSlice> promise);
static void post(CSlice url, Slice data, Promise<BufferSlice> promise);
};
} // namespace td

View File

@@ -0,0 +1,80 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#include "td/net/DarwinHttp.h"
#include "td/utils/logging.h"
#include "td/utils/SliceBuilder.h"
#import <Foundation/Foundation.h>
namespace td {
namespace {
NSURLSession *getSession() {
static NSURLSession *urlSession = [] {
auto configuration = [NSURLSessionConfiguration defaultSessionConfiguration];
configuration.networkServiceType = NSURLNetworkServiceTypeResponsiveData;
configuration.timeoutIntervalForResource = 90;
configuration.waitsForConnectivity = true;
return [NSURLSession sessionWithConfiguration:configuration];
}();
return urlSession;
}
NSString *to_ns_string(CSlice slice) {
return [NSString stringWithUTF8String:slice.c_str()];
}
NSData *to_ns_data(Slice data) {
return [NSData dataWithBytes:static_cast<const void *>(data.data()) length:data.size()];
}
auto http_get(CSlice url) {
auto nsurl = [NSURL URLWithString:to_ns_string(url)];
auto request = [NSURLRequest requestWithURL:nsurl];
return request;
}
auto http_post(CSlice url, Slice data) {
auto nsurl = [NSURL URLWithString:to_ns_string(url)];
auto request = [NSMutableURLRequest requestWithURL:nsurl];
[request setHTTPMethod:@"POST"];
[request setHTTPBody:to_ns_data(data)];
[request setValue:@"keep-alive" forHTTPHeaderField:@"Connection"];
[request setValue:@"" forHTTPHeaderField:@"Host"];
[request setValue:to_ns_string(PSLICE() << data.size()) forHTTPHeaderField:@"Content-Length"];
[request setValue:@"application/x-www-form-urlencoded" forHTTPHeaderField:@"Content-Type"];
return request;
}
void http_send(NSURLRequest *request, Promise<BufferSlice> promise) {
__block auto callback = std::move(promise);
NSURLSessionDataTask *dataTask =
[getSession()
dataTaskWithRequest:request
completionHandler:
^(NSData *data, NSURLResponse *response, NSError *error) {
if (error == nil) {
callback.set_value(BufferSlice(Slice((const char *)([data bytes]), [data length])));
} else {
callback.set_error(Status::Error(static_cast<int32>([error code]), "HTTP request failed"));
}
}];
[dataTask resume];
}
} // namespace
void DarwinHttp::get(CSlice url, Promise<BufferSlice> promise) {
return http_send(http_get(url), std::move(promise));
}
void DarwinHttp::post(CSlice url, Slice data, Promise<BufferSlice> promise) {
return http_send(http_post(url, data), std::move(promise));
}
} // namespace td

View File

@@ -0,0 +1,217 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#include "td/net/GetHostByNameActor.h"
#include "td/net/HttpQuery.h"
#include "td/net/SslCtx.h"
#include "td/net/Wget.h"
#include "td/utils/common.h"
#include "td/utils/JsonBuilder.h"
#include "td/utils/logging.h"
#include "td/utils/misc.h"
#include "td/utils/Slice.h"
#include "td/utils/SliceBuilder.h"
#include "td/utils/Time.h"
namespace td {
namespace detail {
class GoogleDnsResolver final : public Actor {
public:
GoogleDnsResolver(std::string host, bool prefer_ipv6, Promise<IPAddress> promise)
: host_(std::move(host)), prefer_ipv6_(prefer_ipv6), promise_(std::move(promise)) {
}
private:
std::string host_;
bool prefer_ipv6_;
Promise<IPAddress> promise_;
ActorOwn<Wget> wget_;
double begin_time_ = 0;
void start_up() final {
auto r_address = IPAddress::get_ip_address(host_);
if (r_address.is_ok()) {
promise_.set_value(r_address.move_as_ok());
return stop();
}
const int timeout = 10;
const int ttl = 3;
begin_time_ = Time::now();
auto wget_promise = PromiseCreator::lambda([actor_id = actor_id(this)](Result<unique_ptr<HttpQuery>> r_http_query) {
send_closure(actor_id, &GoogleDnsResolver::on_result, std::move(r_http_query));
});
wget_ = create_actor<Wget>(
"GoogleDnsResolver", std::move(wget_promise),
PSTRING() << "https://dns.google/resolve?name=" << url_encode(host_) << "&type=" << (prefer_ipv6_ ? 28 : 1),
std::vector<std::pair<string, string>>({{"Host", "dns.google"}}), timeout, ttl, prefer_ipv6_,
SslCtx::VerifyPeer::Off);
}
static Result<IPAddress> get_ip_address(Result<unique_ptr<HttpQuery>> r_http_query) {
TRY_RESULT(http_query, std::move(r_http_query));
auto get_ip_address = [](JsonValue &answer) -> Result<IPAddress> {
auto &array = answer.get_array();
if (array.empty()) {
return Status::Error("Failed to parse DNS result: Answer is an empty array");
}
if (array[0].type() != JsonValue::Type::Object) {
return Status::Error("Failed to parse DNS result: Answer[0] is not an object");
}
auto &answer_0 = array[0].get_object();
TRY_RESULT(ip_str, answer_0.get_required_string_field("data"));
IPAddress ip;
TRY_STATUS(ip.init_host_port(ip_str, 0));
return ip;
};
if (!http_query->get_arg("Answer").empty()) {
TRY_RESULT(answer, json_decode(http_query->get_arg("Answer")));
if (answer.type() != JsonValue::Type::Array) {
return Status::Error("Expected JSON array");
}
return get_ip_address(answer);
} else {
TRY_RESULT(json_value, json_decode(http_query->content_));
if (json_value.type() != JsonValue::Type::Object) {
return Status::Error("Failed to parse DNS result: not an object");
}
auto &object = json_value.get_object();
TRY_RESULT(answer, object.extract_required_field("Answer", JsonValue::Type::Array));
return get_ip_address(answer);
}
}
void on_result(Result<unique_ptr<HttpQuery>> r_http_query) {
auto end_time = Time::now();
auto result = get_ip_address(std::move(r_http_query));
VLOG(dns_resolver) << "Init IPv" << (prefer_ipv6_ ? "6" : "4") << " host = " << host_ << " in "
<< end_time - begin_time_ << " seconds to "
<< (result.is_ok() ? (PSLICE() << result.ok()) : CSlice("[invalid]"));
promise_.set_result(std::move(result));
stop();
}
};
class NativeDnsResolver final : public Actor {
public:
NativeDnsResolver(std::string host, bool prefer_ipv6, Promise<IPAddress> promise)
: host_(std::move(host)), prefer_ipv6_(prefer_ipv6), promise_(std::move(promise)) {
}
private:
std::string host_;
bool prefer_ipv6_;
Promise<IPAddress> promise_;
void start_up() final {
IPAddress ip;
auto begin_time = Time::now();
auto status = ip.init_host_port(host_, 0, prefer_ipv6_);
auto end_time = Time::now();
VLOG(dns_resolver) << "Init host = " << host_ << " in " << end_time - begin_time << " seconds to " << ip;
if (status.is_error()) {
promise_.set_error(std::move(status));
} else {
promise_.set_value(std::move(ip));
}
stop();
}
};
} // namespace detail
int VERBOSITY_NAME(dns_resolver) = VERBOSITY_NAME(DEBUG);
GetHostByNameActor::GetHostByNameActor(Options options) : options_(std::move(options)) {
CHECK(!options_.resolver_types.empty());
}
void GetHostByNameActor::run(string host, int port, bool prefer_ipv6, Promise<IPAddress> promise) {
auto r_ascii_host = idn_to_ascii(host);
if (r_ascii_host.is_error()) {
return promise.set_error(r_ascii_host.move_as_error());
}
auto ascii_host = r_ascii_host.move_as_ok();
if (ascii_host.empty()) {
return promise.set_error(Status::Error("Host is empty"));
}
auto begin_time = Time::now();
auto &value = cache_[prefer_ipv6].emplace(ascii_host, Value{{}, begin_time - 1.0}).first->second;
if (value.expires_at > begin_time) {
return promise.set_result(value.get_ip_port(port));
}
auto &query_ptr = active_queries_[prefer_ipv6][ascii_host];
if (query_ptr == nullptr) {
query_ptr = make_unique<Query>();
}
auto &query = *query_ptr;
query.promises.emplace_back(port, std::move(promise));
if (query.query.empty()) {
CHECK(query.promises.size() == 1);
query.real_host = std::move(host);
query.begin_time = Time::now();
run_query(std::move(ascii_host), prefer_ipv6, query);
}
}
void GetHostByNameActor::run_query(std::string host, bool prefer_ipv6, Query &query) {
auto promise = PromiseCreator::lambda([actor_id = actor_id(this), host, prefer_ipv6](Result<IPAddress> res) mutable {
send_closure(actor_id, &GetHostByNameActor::on_query_result, std::move(host), prefer_ipv6, std::move(res));
});
CHECK(query.query.empty());
CHECK(query.pos < options_.resolver_types.size());
auto resolver_type = options_.resolver_types[query.pos++];
query.query = [&] {
switch (resolver_type) {
case ResolverType::Native:
return ActorOwn<>(create_actor_on_scheduler<detail::NativeDnsResolver>(
"NativeDnsResolver", options_.scheduler_id, std::move(host), prefer_ipv6, std::move(promise)));
case ResolverType::Google:
return ActorOwn<>(create_actor_on_scheduler<detail::GoogleDnsResolver>(
"GoogleDnsResolver", options_.scheduler_id, std::move(host), prefer_ipv6, std::move(promise)));
default:
UNREACHABLE();
return ActorOwn<>();
}
}();
}
void GetHostByNameActor::on_query_result(std::string host, bool prefer_ipv6, Result<IPAddress> result) {
auto query_it = active_queries_[prefer_ipv6].find(host);
CHECK(query_it != active_queries_[prefer_ipv6].end());
auto &query = *query_it->second;
CHECK(!query.promises.empty());
CHECK(!query.query.empty());
if (result.is_error() && query.pos < options_.resolver_types.size()) {
query.query.reset();
return run_query(std::move(host), prefer_ipv6, query);
}
auto end_time = Time::now();
VLOG(dns_resolver) << "Init host = " << query.real_host << " in total of " << end_time - query.begin_time
<< " seconds to " << (result.is_ok() ? (PSLICE() << result.ok()) : CSlice("[invalid]"));
auto promises = std::move(query.promises);
auto value_it = cache_[prefer_ipv6].find(host);
CHECK(value_it != cache_[prefer_ipv6].end());
auto cache_timeout = result.is_ok() ? options_.ok_timeout : options_.error_timeout;
value_it->second = Value{std::move(result), end_time + cache_timeout};
active_queries_[prefer_ipv6].erase(query_it);
for (auto &promise : promises) {
promise.second.set_result(value_it->second.get_ip_port(promise.first));
}
}
} // namespace td

View File

@@ -0,0 +1,76 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once
#include "td/actor/actor.h"
#include "td/utils/common.h"
#include "td/utils/FlatHashMap.h"
#include "td/utils/logging.h"
#include "td/utils/port/IPAddress.h"
#include "td/utils/Promise.h"
#include "td/utils/Status.h"
#include <utility>
namespace td {
extern int VERBOSITY_NAME(dns_resolver);
class GetHostByNameActor final : public Actor {
public:
enum class ResolverType { Native, Google };
struct Options {
static constexpr int32 DEFAULT_CACHE_TIME = 60 * 29; // 29 minutes
static constexpr int32 DEFAULT_ERROR_CACHE_TIME = 60 * 5; // 5 minutes
vector<ResolverType> resolver_types{ResolverType::Native};
int32 scheduler_id{-1};
int32 ok_timeout{DEFAULT_CACHE_TIME};
int32 error_timeout{DEFAULT_ERROR_CACHE_TIME};
};
explicit GetHostByNameActor(Options options);
void run(std::string host, int port, bool prefer_ipv6, Promise<IPAddress> promise);
private:
void on_query_result(std::string host, bool prefer_ipv6, Result<IPAddress> result);
struct Value {
Result<IPAddress> ip;
double expires_at;
Value(Result<IPAddress> ip, double expires_at) : ip(std::move(ip)), expires_at(expires_at) {
}
Result<IPAddress> get_ip_port(int port) const {
auto result = ip.clone();
if (result.is_ok()) {
result.ok_ref().set_port(port);
}
return result;
}
};
FlatHashMap<string, Value> cache_[2];
struct Query {
ActorOwn<> query;
size_t pos = 0;
string real_host;
double begin_time = 0.0;
std::vector<std::pair<int, Promise<IPAddress>>> promises;
};
FlatHashMap<string, unique_ptr<Query>> active_queries_[2];
Options options_;
void run_query(std::string host, bool prefer_ipv6, Query &query);
};
} // namespace td

View File

@@ -0,0 +1,75 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#include "td/net/HttpChunkedByteFlow.h"
#include "td/utils/find_boundary.h"
#include "td/utils/format.h"
#include "td/utils/misc.h"
#include "td/utils/SliceBuilder.h"
#include "td/utils/Status.h"
namespace td {
bool HttpChunkedByteFlow::loop() {
bool result = false;
do {
if (state_ == State::ReadChunkLength) {
bool ok = find_boundary(input_->clone(), "\r\n", len_);
if (len_ > 8) {
finish(Status::Error(PSLICE() << "Too long length in chunked "
<< input_->cut_head(len_).move_as_buffer_slice().as_slice()));
return false;
}
if (!ok) {
set_need_size(input_->size() + 1);
break;
}
auto s_len = input_->cut_head(len_).move_as_buffer_slice();
input_->advance(2);
len_ = hex_to_integer<size_t>(s_len.as_slice());
save_len_ = len_;
state_ = State::ReadChunkContent;
}
auto size = input_->size();
auto ready = min(len_, size);
auto need_size = min(MIN_UPDATE_SIZE, len_) + 2;
if (size < need_size) {
set_need_size(need_size);
break;
}
if (total_size_ > MAX_SIZE - ready) {
finish(Status::Error(PSLICE() << "Too big query " << tag("size", input_->size())));
return false;
}
total_size_ += ready;
output_.append(input_->cut_head(ready));
result = true;
len_ -= ready;
if (len_ == 0) {
if (input_->size() < 2) {
set_need_size(2);
break;
}
input_->advance(2);
total_size_ += 2;
if (save_len_ == 0) {
finish(Status::OK());
return false;
}
state_ = State::ReadChunkLength;
}
} while (false);
if (!is_input_active_ && !result) {
finish(Status::Error("Unexpected end of stream"));
}
return result;
}
} // namespace td

View File

@@ -0,0 +1,29 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once
#include "td/utils/ByteFlow.h"
#include <limits>
namespace td {
class HttpChunkedByteFlow final : public ByteFlowBase {
public:
bool loop() final;
private:
static constexpr size_t MAX_SIZE = std::numeric_limits<uint32>::max(); // some reasonable limit
static constexpr size_t MIN_UPDATE_SIZE = 1 << 14;
enum class State { ReadChunkLength, ReadChunkContent, OK };
State state_ = State::ReadChunkLength;
size_t len_ = 0;
size_t save_len_ = 0;
size_t total_size_ = 0;
};
} // namespace td

View File

@@ -0,0 +1,210 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#include "td/net/HttpConnectionBase.h"
#include "td/net/HttpHeaderCreator.h"
#include "td/utils/common.h"
#include "td/utils/logging.h"
#include "td/utils/misc.h"
#include "td/utils/port/detail/PollableFd.h"
namespace td {
namespace detail {
HttpConnectionBase::HttpConnectionBase(State state, BufferedFd<SocketFd> fd, SslStream ssl_stream, size_t max_post_size,
size_t max_files, int32 idle_timeout, int32 slow_scheduler_id)
: state_(state)
, fd_(std::move(fd))
, ssl_stream_(std::move(ssl_stream))
, max_post_size_(max_post_size)
, max_files_(max_files)
, idle_timeout_(idle_timeout)
, slow_scheduler_id_(slow_scheduler_id) {
CHECK(state_ != State::Close);
if (ssl_stream_) {
read_source_ >> ssl_stream_.read_byte_flow() >> read_sink_;
write_source_ >> ssl_stream_.write_byte_flow() >> write_sink_;
} else {
read_source_ >> read_sink_;
write_source_ >> write_sink_;
}
peer_address_.init_peer_address(fd_).ignore();
}
void HttpConnectionBase::live_event() {
if (idle_timeout_ != 0) {
set_timeout_in(idle_timeout_);
}
}
void HttpConnectionBase::start_up() {
Scheduler::subscribe(fd_.get_poll_info().extract_pollable_fd(this));
reader_.init(read_sink_.get_output(), max_post_size_, max_files_);
if (state_ == State::Read) {
current_query_ = make_unique<HttpQuery>();
}
live_event();
yield();
}
void HttpConnectionBase::tear_down() {
Scheduler::unsubscribe_before_close(fd_.get_poll_info().get_pollable_fd_ref());
fd_.close();
}
void HttpConnectionBase::write_next_noflush(BufferSlice buffer) {
CHECK(state_ == State::Write);
write_buffer_.append(std::move(buffer));
}
void HttpConnectionBase::write_next(BufferSlice buffer) {
write_next_noflush(std::move(buffer));
loop();
}
void HttpConnectionBase::write_ok() {
CHECK(state_ == State::Write);
current_query_ = make_unique<HttpQuery>();
state_ = State::Read;
live_event();
loop();
}
void HttpConnectionBase::write_error(Status error) {
CHECK(state_ == State::Write);
LOG(WARNING) << "Close HTTP connection: " << error;
state_ = State::Close;
loop();
}
void HttpConnectionBase::timeout_expired() {
LOG(INFO) << "Idle timeout expired";
if (fd_.need_flush_write()) {
on_error(Status::Error("Write timeout expired"));
} else if (state_ == State::Read) {
on_error(Status::Error("Read timeout expired"));
}
stop();
}
void HttpConnectionBase::loop() {
if (ssl_stream_) {
//ssl_stream_.read_byte_flow().set_need_size(0);
ssl_stream_.write_byte_flow().reset_need_size();
}
sync_with_poll(fd_);
if (can_read_local(fd_)) {
LOG(DEBUG) << "Can read from the connection";
auto r = fd_.flush_read();
if (r.is_error()) {
if (!begins_with(r.error().message(), "SSL error {336134278")) { // if error is not yet outputted
LOG(INFO) << "Receive flush_read error: " << r.error();
}
on_error(Status::Error(r.error().public_message()));
return stop();
}
}
read_source_.wakeup();
// TODO: read_next even when state_ == State::Write
bool want_read = false;
bool can_be_slow = slow_scheduler_id_ == -1;
if (state_ == State::Read) {
auto res = reader_.read_next(current_query_.get(), can_be_slow);
if (res.is_error()) {
if (res.error().message() == "SLOW") {
LOG(INFO) << "Slow HTTP connection: migrate to " << slow_scheduler_id_;
CHECK(!can_be_slow);
yield();
migrate(slow_scheduler_id_);
slow_scheduler_id_ = -1;
return;
}
live_event();
state_ = State::Write;
if (res.error().code() == 500) {
LOG(WARNING) << "Failed to process an HTTP query: " << res.error();
} else {
LOG(INFO) << res.error();
}
HttpHeaderCreator hc;
hc.init_status_line(res.error().code());
hc.set_content_size(0);
write_buffer_.append(hc.finish().ok());
close_after_write_ = true;
on_error(Status::Error(res.error().public_message()));
} else if (res.ok() == 0) {
state_ = State::Write;
LOG(DEBUG) << "Send query to handler";
live_event();
current_query_->peer_address_ = peer_address_;
on_query(std::move(current_query_));
} else {
want_read = true;
}
}
write_source_.wakeup();
if (can_write_local(fd_)) {
LOG(DEBUG) << "Can write to the connection";
auto r = fd_.flush_write();
if (r.is_error()) {
LOG(INFO) << "Receive flush_write error: " << r.error();
on_error(Status::Error(r.error().public_message()));
}
if (close_after_write_ && !fd_.need_flush_write()) {
return stop();
}
}
Status pending_error;
if (fd_.get_poll_info().get_flags_local().has_pending_error()) {
pending_error = fd_.get_pending_error();
}
if (pending_error.is_ok() && write_sink_.status().is_error()) {
pending_error = std::move(write_sink_.status());
}
if (pending_error.is_ok() && read_sink_.status().is_error()) {
pending_error = std::move(read_sink_.status());
}
if (pending_error.is_error()) {
LOG(INFO) << pending_error;
if (!close_after_write_) {
on_error(Status::Error(pending_error.public_message()));
}
state_ = State::Close;
}
if (can_close_local(fd_)) {
LOG(DEBUG) << "Can close the connection";
state_ = State::Close;
}
if (state_ == State::Close) {
if (fd_.need_flush_write()) {
LOG(INFO) << "Close nonempty connection";
}
if (want_read && (!fd_.input_buffer().empty() || current_query_->type_ != HttpQuery::Type::Empty)) {
LOG(INFO) << "Close connection while reading request/response";
}
return stop();
}
}
void HttpConnectionBase::on_start_migrate(int32 sched_id) {
Scheduler::unsubscribe(fd_.get_poll_info().get_pollable_fd_ref());
}
void HttpConnectionBase::on_finish_migrate() {
Scheduler::subscribe(fd_.get_poll_info().extract_pollable_fd(this));
live_event();
}
} // namespace detail
} // namespace td

View File

@@ -0,0 +1,77 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once
#include "td/net/HttpQuery.h"
#include "td/net/HttpReader.h"
#include "td/net/SslStream.h"
#include "td/actor/actor.h"
#include "td/utils/buffer.h"
#include "td/utils/BufferedFd.h"
#include "td/utils/ByteFlow.h"
#include "td/utils/port/IPAddress.h"
#include "td/utils/port/SocketFd.h"
#include "td/utils/Status.h"
namespace td {
namespace detail {
class HttpConnectionBase : public Actor {
public:
void write_next_noflush(BufferSlice buffer);
void write_next(BufferSlice buffer);
void write_ok();
void write_error(Status error);
protected:
enum class State { Read, Write, Close };
HttpConnectionBase(State state, BufferedFd<SocketFd> fd, SslStream ssl_stream, size_t max_post_size, size_t max_files,
int32 idle_timeout, int32 slow_scheduler_id);
private:
State state_;
BufferedFd<SocketFd> fd_;
IPAddress peer_address_;
SslStream ssl_stream_;
ByteFlowSource read_source_{&fd_.input_buffer()};
ByteFlowSink read_sink_;
ChainBufferWriter write_buffer_;
ChainBufferReader write_buffer_reader_ = write_buffer_.extract_reader();
ByteFlowSource write_source_{&write_buffer_reader_};
ByteFlowMoveSink write_sink_{&fd_.output_buffer()};
size_t max_post_size_;
size_t max_files_;
int32 idle_timeout_;
HttpReader reader_;
unique_ptr<HttpQuery> current_query_;
bool close_after_write_ = false;
int32 slow_scheduler_id_{-1};
void live_event();
void start_up() final;
void tear_down() final;
void timeout_expired() final;
void loop() final;
void on_start_migrate(int32 sched_id) final;
void on_finish_migrate() final;
virtual void on_query(unique_ptr<HttpQuery> query) = 0;
virtual void on_error(Status error) = 0;
};
} // namespace detail
} // namespace td

View File

@@ -0,0 +1,36 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#include "td/net/HttpContentLengthByteFlow.h"
#include "td/utils/Status.h"
namespace td {
bool HttpContentLengthByteFlow::loop() {
auto ready_size = input_->size();
if (ready_size > len_) {
ready_size = len_;
}
auto need_size = min(MIN_UPDATE_SIZE, len_);
if (ready_size < need_size) {
set_need_size(need_size);
return false;
}
output_.append(input_->cut_head(ready_size));
len_ -= ready_size;
if (len_ == 0) {
finish(Status::OK());
return false;
}
if (!is_input_active_) {
finish(Status::Error("Unexpected end of stream"));
return false;
}
return true;
}
} // namespace td

View File

@@ -0,0 +1,25 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once
#include "td/utils/ByteFlow.h"
namespace td {
class HttpContentLengthByteFlow final : public ByteFlowBase {
public:
HttpContentLengthByteFlow() = default;
explicit HttpContentLengthByteFlow(size_t len) : len_(len) {
}
bool loop() final;
private:
static constexpr size_t MIN_UPDATE_SIZE = 1 << 14;
size_t len_ = 0;
};
} // namespace td

View File

@@ -0,0 +1,25 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#include "td/net/HttpFile.h"
#include "td/net/HttpReader.h"
#include "td/utils/format.h"
namespace td {
HttpFile::~HttpFile() {
if (!temp_file_name.empty()) {
HttpReader::delete_temp_file(temp_file_name);
}
}
StringBuilder &operator<<(StringBuilder &sb, const HttpFile &file) {
return sb << tag("name", file.name) << tag("size", file.size);
}
} // namespace td

View File

@@ -0,0 +1,49 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once
#include "td/utils/common.h"
#include "td/utils/StringBuilder.h"
namespace td {
class HttpFile {
public:
string field_name;
string name;
string content_type;
int64 size;
string temp_file_name;
HttpFile(string field_name, string name, string content_type, int64 size, string temp_file_name)
: field_name(std::move(field_name))
, name(std::move(name))
, content_type(std::move(content_type))
, size(size)
, temp_file_name(std::move(temp_file_name)) {
}
HttpFile(const HttpFile &) = delete;
HttpFile &operator=(const HttpFile &) = delete;
HttpFile(HttpFile &&other) noexcept
: field_name(std::move(other.field_name))
, name(std::move(other.name))
, content_type(std::move(other.content_type))
, size(other.size)
, temp_file_name(std::move(other.temp_file_name)) {
other.temp_file_name.clear();
}
HttpFile &operator=(HttpFile &&) = delete;
~HttpFile();
};
StringBuilder &operator<<(StringBuilder &sb, const HttpFile &file);
} // namespace td

View File

@@ -0,0 +1,153 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once
#include "td/utils/logging.h"
#include "td/utils/Slice.h"
#include "td/utils/SliceBuilder.h"
#include "td/utils/Status.h"
#include "td/utils/StringBuilder.h"
namespace td {
class HttpHeaderCreator {
public:
static constexpr size_t MAX_HEADER = 4096;
HttpHeaderCreator() : sb_(MutableSlice{header_, MAX_HEADER}) {
}
void init_ok() {
sb_ = StringBuilder(MutableSlice{header_, MAX_HEADER});
sb_ << "HTTP/1.1 200 OK\r\n";
}
void init_get(Slice url) {
sb_ = StringBuilder(MutableSlice{header_, MAX_HEADER});
sb_ << "GET " << url << " HTTP/1.1\r\n";
}
void init_post(Slice url) {
sb_ = StringBuilder(MutableSlice{header_, MAX_HEADER});
sb_ << "POST " << url << " HTTP/1.1\r\n";
}
void init_error(int code, Slice reason) {
sb_ = StringBuilder(MutableSlice{header_, MAX_HEADER});
sb_ << "HTTP/1.1 " << code << " " << reason << "\r\n";
}
void init_status_line(int http_status_code) {
init_error(http_status_code, get_status_line(http_status_code));
}
void add_header(Slice key, Slice value) {
sb_ << key << ": " << value << "\r\n";
}
void set_content_type(Slice type) {
add_header("Content-Type", type);
}
void set_content_size(size_t size) {
add_header("Content-Length", PSLICE() << size);
}
void set_keep_alive() {
add_header("Connection", "keep-alive");
}
Result<Slice> finish(Slice content = {}) TD_WARN_UNUSED_RESULT {
sb_ << "\r\n";
if (!content.empty()) {
sb_ << content;
}
if (sb_.is_error()) {
return Status::Error("Too many headers");
}
return sb_.as_cslice();
}
private:
static CSlice get_status_line(int http_status_code) {
if (http_status_code == 200) {
return CSlice("OK");
}
switch (http_status_code) {
case 201:
return CSlice("Created");
case 202:
return CSlice("Accepted");
case 204:
return CSlice("No Content");
case 206:
return CSlice("Partial Content");
case 301:
return CSlice("Moved Permanently");
case 302:
return CSlice("Found");
case 303:
return CSlice("See Other");
case 304:
return CSlice("Not Modified");
case 307:
return CSlice("Temporary Redirect");
case 308:
return CSlice("Permanent Redirect");
case 400:
return CSlice("Bad Request");
case 401:
return CSlice("Unauthorized");
case 403:
return CSlice("Forbidden");
case 404:
return CSlice("Not Found");
case 405:
return CSlice("Method Not Allowed");
case 406:
return CSlice("Not Acceptable");
case 408:
return CSlice("Request Timeout");
case 409:
return CSlice("Conflict");
case 410:
return CSlice("Gone");
case 411:
return CSlice("Length Required");
case 412:
return CSlice("Precondition Failed");
case 413:
return CSlice("Request Entity Too Large");
case 414:
return CSlice("Request-URI Too Long");
case 415:
return CSlice("Unsupported Media Type");
case 416:
return CSlice("Range Not Satisfiable");
case 417:
return CSlice("Expectation Failed");
case 418:
return CSlice("I'm a teapot");
case 421:
return CSlice("Misdirected Request");
case 426:
return CSlice("Upgrade Required");
case 429:
return CSlice("Too Many Requests");
case 431:
return CSlice("Request Header Fields Too Large");
case 480:
return CSlice("Temporarily Unavailable");
case 501:
return CSlice("Not Implemented");
case 502:
return CSlice("Bad Gateway");
case 503:
return CSlice("Service Unavailable");
case 505:
return CSlice("HTTP Version Not Supported");
default:
LOG_IF(ERROR, http_status_code != 500) << "Unsupported status code " << http_status_code << " returned";
return CSlice("Internal Server Error");
}
}
char header_[MAX_HEADER];
StringBuilder sb_;
};
} // namespace td

View File

@@ -0,0 +1,32 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#include "td/net/HttpInboundConnection.h"
#include "td/net/SslStream.h"
#include "td/utils/common.h"
namespace td {
HttpInboundConnection::HttpInboundConnection(BufferedFd<SocketFd> fd, size_t max_post_size, size_t max_files,
int32 idle_timeout, ActorShared<Callback> callback,
int32 slow_scheduler_id)
: HttpConnectionBase(State::Read, std::move(fd), SslStream(), max_post_size, max_files, idle_timeout,
slow_scheduler_id)
, callback_(std::move(callback)) {
}
void HttpInboundConnection::on_query(unique_ptr<HttpQuery> query) {
CHECK(!callback_.empty());
send_closure(callback_, &Callback::handle, std::move(query), ActorOwn<HttpInboundConnection>(actor_id(this)));
}
void HttpInboundConnection::on_error(Status error) {
// nothing to do
}
} // namespace td

View File

@@ -0,0 +1,44 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once
#include "td/net/HttpConnectionBase.h"
#include "td/net/HttpQuery.h"
#include "td/actor/actor.h"
#include "td/utils/BufferedFd.h"
#include "td/utils/port/SocketFd.h"
#include "td/utils/Status.h"
namespace td {
class HttpInboundConnection final : public detail::HttpConnectionBase {
public:
class Callback : public Actor {
public:
virtual void handle(unique_ptr<HttpQuery> query, ActorOwn<HttpInboundConnection> connection) = 0;
};
// Inherited interface
// void write_next(BufferSlice buffer);
// void write_ok();
// void write_error(Status error);
HttpInboundConnection(BufferedFd<SocketFd> fd, size_t max_post_size, size_t max_files, int32 idle_timeout,
ActorShared<Callback> callback, int32 slow_scheduler_id = -1);
private:
void on_query(unique_ptr<HttpQuery> query) final;
void on_error(Status error) final;
void hangup() final {
callback_.release();
stop();
}
ActorShared<Callback> callback_;
};
} // namespace td

View File

@@ -0,0 +1,23 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#include "td/net/HttpOutboundConnection.h"
#include "td/utils/common.h"
namespace td {
void HttpOutboundConnection::on_query(unique_ptr<HttpQuery> query) {
CHECK(!callback_.empty());
send_closure(callback_, &Callback::handle, std::move(query));
}
void HttpOutboundConnection::on_error(Status error) {
CHECK(!callback_.empty());
send_closure(callback_, &Callback::on_connection_error, std::move(error));
}
} // namespace td

View File

@@ -0,0 +1,49 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once
#include "td/net/HttpConnectionBase.h"
#include "td/net/HttpQuery.h"
#include "td/net/SslStream.h"
#include "td/actor/actor.h"
#include "td/utils/BufferedFd.h"
#include "td/utils/port/SocketFd.h"
#include "td/utils/Status.h"
namespace td {
class HttpOutboundConnection final : public detail::HttpConnectionBase {
public:
class Callback : public Actor {
public:
virtual void handle(unique_ptr<HttpQuery> query) = 0;
virtual void on_connection_error(Status error) = 0; // TODO rename to on_error
};
HttpOutboundConnection(BufferedFd<SocketFd> fd, SslStream ssl_stream, size_t max_post_size, size_t max_files,
int32 idle_timeout, ActorShared<Callback> callback, int32 slow_scheduler_id = -1)
: HttpConnectionBase(HttpConnectionBase::State::Write, std::move(fd), std::move(ssl_stream), max_post_size,
max_files, idle_timeout, slow_scheduler_id)
, callback_(std::move(callback)) {
}
// Inherited interface
// void write_next(BufferSlice buffer);
// void write_ok();
// void write_error(Status error);
private:
void on_query(unique_ptr<HttpQuery> query) final;
void on_error(Status error) final;
void hangup() final {
callback_.release();
HttpConnectionBase::hangup();
}
ActorShared<Callback> callback_;
};
} // namespace td

View File

@@ -0,0 +1,111 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#include "td/net/HttpProxy.h"
#include "td/utils/base64.h"
#include "td/utils/common.h"
#include "td/utils/format.h"
#include "td/utils/logging.h"
#include "td/utils/misc.h"
#include "td/utils/Slice.h"
#include "td/utils/SliceBuilder.h"
namespace td {
void HttpProxy::send_connect() {
VLOG(proxy) << "Send CONNECT to proxy";
CHECK(state_ == State::SendConnect);
state_ = State::WaitConnectResponse;
string host = PSTRING() << ip_address_.get_ip_host() << ':' << ip_address_.get_port();
string proxy_authorization;
if (!username_.empty() || !password_.empty()) {
auto userinfo = PSTRING() << username_ << ':' << password_;
proxy_authorization = PSTRING() << "Proxy-Authorization: Basic " << base64_encode(userinfo) << "\r\n";
VLOG(proxy) << "Use credentials to connect to proxy: " << proxy_authorization;
}
fd_.output_buffer().append(PSLICE() << "CONNECT " << host << " HTTP/1.1\r\n"
<< "Host: " << host << "\r\n"
<< proxy_authorization << "\r\n");
}
Status HttpProxy::wait_connect_response() {
CHECK(state_ == State::WaitConnectResponse);
auto it = fd_.input_buffer().clone();
VLOG(proxy) << "Receive CONNECT response of size " << it.size();
if (it.size() < 12 + 1 + 1) {
return Status::OK();
}
char begin_buf[12];
MutableSlice begin(begin_buf, 12);
it.advance(12, begin);
if ((begin.substr(0, 10) != "HTTP/1.1 2" && begin.substr(0, 10) != "HTTP/1.0 2") || !is_digit(begin[10]) ||
!is_digit(begin[11])) {
char buf[1024];
size_t len = min(sizeof(buf), it.size());
it.advance(len, MutableSlice{buf, sizeof(buf)});
VLOG(proxy) << "Failed to connect: " << format::escaped(begin) << format::escaped(Slice(buf, len));
return Status::Error(PSLICE() << "Failed to connect to " << ip_address_.get_ip_host() << ':'
<< ip_address_.get_port());
}
size_t total_size = 12;
char c;
MutableSlice c_slice(&c, 1);
while (!it.empty()) {
it.advance(1, c_slice);
total_size++;
if (c == '\n') {
break;
}
}
if (it.empty()) {
return Status::OK();
}
char prev = '\n';
size_t pos = 0;
bool found = false;
while (!it.empty()) {
it.advance(1, c_slice);
total_size++;
if (c == '\n') {
if (pos == 0 || (pos == 1 && prev == '\r')) {
found = true;
break;
}
pos = 0;
} else {
pos++;
}
prev = c;
}
if (!found) {
CHECK(it.empty());
return Status::OK();
}
fd_.input_buffer().advance(total_size);
stop();
return Status::OK();
}
Status HttpProxy::loop_impl() {
switch (state_) {
case State::SendConnect:
send_connect();
break;
case State::WaitConnectResponse:
TRY_STATUS(wait_connect_response());
break;
default:
UNREACHABLE();
}
return Status::OK();
}
} // namespace td

View File

@@ -0,0 +1,28 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once
#include "td/net/TransparentProxy.h"
#include "td/utils/Status.h"
namespace td {
class HttpProxy final : public TransparentProxy {
public:
using TransparentProxy::TransparentProxy;
private:
enum class State { SendConnect, WaitConnectResponse } state_ = State::SendConnect;
void send_connect();
Status wait_connect_response();
Status loop_impl() final;
};
} // namespace td

View File

@@ -0,0 +1,86 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#include "td/net/HttpQuery.h"
#include "td/utils/misc.h"
#include <algorithm>
namespace td {
Slice HttpQuery::get_header(Slice key) const {
auto it = std::find_if(headers_.begin(), headers_.end(),
[&key](const std::pair<MutableSlice, MutableSlice> &s) { return s.first == key; });
return it == headers_.end() ? Slice() : it->second;
}
MutableSlice HttpQuery::get_arg(Slice key) const {
auto it = std::find_if(args_.begin(), args_.end(),
[&key](const std::pair<MutableSlice, MutableSlice> &s) { return s.first == key; });
return it == args_.end() ? MutableSlice() : it->second;
}
vector<std::pair<string, string>> HttpQuery::get_args() const {
vector<std::pair<string, string>> res;
res.reserve(args_.size());
for (auto &it : args_) {
res.emplace_back(it.first.str(), it.second.str());
}
return res;
}
int HttpQuery::get_retry_after() const {
auto value = get_header("retry-after");
if (value.empty()) {
return 0;
}
auto r_retry_after = to_integer_safe<int>(value);
if (r_retry_after.is_error()) {
return 0;
}
return td::max(0, r_retry_after.ok());
}
StringBuilder &operator<<(StringBuilder &sb, const HttpQuery &q) {
switch (q.type_) {
case HttpQuery::Type::Empty:
sb << "EMPTY";
return sb;
case HttpQuery::Type::Get:
sb << "GET";
break;
case HttpQuery::Type::Post:
sb << "POST";
break;
case HttpQuery::Type::Response:
sb << "RESPONSE";
break;
}
if (q.type_ == HttpQuery::Type::Response) {
sb << ":" << q.code_ << ":" << q.reason_;
} else {
sb << ":" << q.url_path_;
for (auto &key_value : q.args_) {
sb << ":[" << key_value.first << ":" << key_value.second << "]";
}
}
if (q.keep_alive_) {
sb << ":keep-alive";
}
sb << "\n";
for (auto &key_value : q.headers_) {
sb << key_value.first << "=" << key_value.second << "\n";
}
sb << "BEGIN CONTENT\n";
sb << q.content_;
sb << "END CONTENT\n";
return sb;
}
} // namespace td

View File

@@ -0,0 +1,50 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once
#include "td/net/HttpFile.h"
#include "td/utils/buffer.h"
#include "td/utils/common.h"
#include "td/utils/port/IPAddress.h"
#include "td/utils/Slice.h"
#include "td/utils/StringBuilder.h"
#include <utility>
namespace td {
class HttpQuery {
public:
enum class Type : int8 { Empty, Get, Post, Response };
vector<BufferSlice> container_;
Type type_ = Type::Empty;
int32 code_ = 0;
MutableSlice url_path_;
vector<std::pair<MutableSlice, MutableSlice>> args_;
MutableSlice reason_;
bool keep_alive_ = true;
vector<std::pair<MutableSlice, MutableSlice>> headers_;
vector<HttpFile> files_;
MutableSlice content_;
IPAddress peer_address_;
Slice get_header(Slice key) const;
MutableSlice get_arg(Slice key) const;
vector<std::pair<string, string>> get_args() const;
int get_retry_after() const;
};
StringBuilder &operator<<(StringBuilder &sb, const HttpQuery &q);
} // namespace td

View File

@@ -0,0 +1,889 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#include "td/net/HttpReader.h"
#include "td/utils/filesystem.h"
#include "td/utils/find_boundary.h"
#include "td/utils/format.h"
#include "td/utils/Gzip.h"
#include "td/utils/HttpUrl.h"
#include "td/utils/JsonBuilder.h"
#include "td/utils/logging.h"
#include "td/utils/misc.h"
#include "td/utils/Parser.h"
#include "td/utils/PathView.h"
#include "td/utils/port/path.h"
#include "td/utils/SliceBuilder.h"
#include <cstddef>
#include <cstring>
namespace td {
constexpr const char HttpReader::TEMP_DIRECTORY_PREFIX[];
void HttpReader::init(ChainBufferReader *input, size_t max_post_size, size_t max_files) {
input_ = input;
state_ = State::ReadHeaders;
headers_read_length_ = 0;
content_length_ = -1;
query_ = nullptr;
max_post_size_ = max_post_size;
max_files_ = max_files;
total_parameters_length_ = 0;
total_headers_length_ = 0;
}
Result<size_t> HttpReader::read_next(HttpQuery *query, bool can_be_slow) {
if (query_ != query) {
CHECK(query_ == nullptr);
query_ = query;
}
auto r_size = do_read_next(can_be_slow);
if (state_ != State::ReadHeaders && flow_sink_.is_ready() && r_size.is_ok() && r_size.ok() > 0) {
CHECK(flow_sink_.status().is_ok());
return Status::Error(400, "Bad Request: unexpected end of request content");
}
return r_size;
}
Result<size_t> HttpReader::do_read_next(bool can_be_slow) {
size_t need_size = input_->size() + 1;
while (true) {
if (state_ != State::ReadHeaders) {
gzip_flow_.wakeup();
flow_source_.wakeup();
if (flow_sink_.is_ready() && flow_sink_.status().is_error()) {
if (!temp_file_.empty()) {
clean_temporary_file();
}
return Status::Error(400, PSLICE() << "Bad Request: " << flow_sink_.status().message());
}
need_size = flow_source_.get_need_size();
if (need_size == 0) {
need_size = input_->size() + 1;
}
}
switch (state_) {
case State::ReadHeaders: {
auto result = split_header();
if (result.is_error() || result.ok() != 0) {
return result;
}
if (transfer_encoding_.empty() && content_length_ <= 0) {
break;
}
flow_source_ = ByteFlowSource(input_);
ByteFlowInterface *source = &flow_source_;
if (transfer_encoding_.empty()) {
content_length_flow_ = HttpContentLengthByteFlow(narrow_cast<size_t>(content_length_));
*source >> content_length_flow_;
source = &content_length_flow_;
} else if (transfer_encoding_ == "chunked") {
chunked_flow_ = HttpChunkedByteFlow();
*source >> chunked_flow_;
source = &chunked_flow_;
} else {
LOG(ERROR) << "Unsupported " << tag("transfer-encoding", transfer_encoding_);
return Status::Error(501, "Unimplemented: unsupported transfer-encoding");
}
if (content_encoding_.empty() || content_encoding_ == "none") {
} else if (content_encoding_ == "gzip" || content_encoding_ == "deflate") {
gzip_flow_ = GzipByteFlow(Gzip::Mode::Decode);
GzipByteFlow::Options options;
options.write_watermark.low = 0;
options.write_watermark.high = max(max_post_size_, MAX_TOTAL_PARAMETERS_LENGTH + 1);
gzip_flow_.set_options(options);
gzip_flow_.set_max_output_size(MAX_CONTENT_SIZE);
*source >> gzip_flow_;
source = &gzip_flow_;
} else {
LOG(WARNING) << "Unsupported " << tag("content-encoding", content_encoding_);
return Status::Error(415, "Unsupported Media Type: unsupported content-encoding");
}
flow_sink_ = ByteFlowSink();
*source >> flow_sink_;
content_ = flow_sink_.get_output();
if (content_length_ >= static_cast<int64>(MAX_CONTENT_SIZE)) {
return Status::Error(413, PSLICE() << "Request Entity Too Large: content length is " << content_length_);
}
if (content_type_lowercased_.find("multipart/form-data") != string::npos) {
state_ = State::ReadMultipartFormData;
const char *p = std::strstr(content_type_lowercased_.c_str(), "boundary");
if (p == nullptr) {
return Status::Error(400, "Bad Request: boundary not found");
}
p += 8;
std::ptrdiff_t offset = p - content_type_lowercased_.c_str();
p = static_cast<const char *>(
std::memchr(content_type_.begin() + offset, '=', content_type_.size() - offset));
if (p == nullptr) {
return Status::Error(400, "Bad Request: boundary value not found");
}
p++;
auto end_p = static_cast<const char *>(std::memchr(p, ';', content_type_.end() - p));
if (end_p == nullptr) {
end_p = content_type_.end();
}
if (*p == '"' && p + 1 < end_p && end_p[-1] == '"') {
p++;
end_p--;
}
CHECK(p != nullptr);
Slice boundary(p, static_cast<size_t>(end_p - p));
if (boundary.empty() || boundary.size() > MAX_BOUNDARY_LENGTH) {
return Status::Error(400, "Bad Request: boundary too big or empty");
}
boundary_ = "\r\n--" + boundary.str();
form_data_parse_state_ = FormDataParseState::SkipPrologue;
form_data_read_length_ = 0;
form_data_skipped_length_ = 0;
} else if (content_type_lowercased_.find("application/x-www-form-urlencoded") != string::npos ||
content_type_lowercased_.find("application/json") != string::npos) {
state_ = State::ReadArgs;
} else {
form_data_skipped_length_ = 0;
state_ = State::ReadContent;
}
continue;
}
case State::ReadContent: {
if (content_->size() > max_post_size_) {
state_ = State::ReadContentToFile;
GzipByteFlow::Options options;
options.write_watermark.low = 4 << 20;
options.write_watermark.high = 8 << 20;
gzip_flow_.set_options(options);
continue;
}
if (flow_sink_.is_ready()) {
CHECK(query_->container_.size() == 1u);
query_->container_.emplace_back(content_->cut_head(content_->size()).move_as_buffer_slice());
query_->content_ = query_->container_.back().as_mutable_slice();
break;
}
return need_size;
}
case State::ReadContentToFile: {
if (!can_be_slow) {
return Status::Error("SLOW");
}
// save content to a file
if (temp_file_.empty()) {
auto open_status = open_temp_file("file");
if (open_status.is_error()) {
return Status::Error(500, "Internal Server Error: can't create temporary file");
}
}
auto size = content_->size();
bool restart = false;
if (size > (1 << 20) || flow_sink_.is_ready()) {
TRY_STATUS(save_file_part(content_->cut_head(size).move_as_buffer_slice()));
restart = true;
}
if (flow_sink_.is_ready()) {
query_->files_.emplace_back("file", "", content_type_.str(), file_size_, temp_file_name_);
close_temp_file();
break;
}
if (restart) {
continue;
}
return need_size;
}
case State::ReadArgs: {
auto size = content_->size();
if (size > MAX_TOTAL_PARAMETERS_LENGTH - total_parameters_length_) {
return Status::Error(413, "Request Entity Too Large: too many parameters");
}
if (flow_sink_.is_ready()) {
query_->container_.emplace_back(content_->cut_head(size).move_as_buffer_slice());
Status result;
if (content_type_lowercased_.find("application/x-www-form-urlencoded") != string::npos) {
result = parse_parameters(query_->container_.back().as_mutable_slice());
} else {
result = parse_json_parameters(query_->container_.back().as_mutable_slice());
}
if (result.is_error()) {
if (result.code() == 413) {
return std::move(result);
}
LOG(INFO) << result.message();
}
query_->content_ = MutableSlice();
break;
}
return need_size;
}
case State::ReadMultipartFormData: {
if (!content_->empty() || flow_sink_.is_ready()) {
TRY_RESULT(result, parse_multipart_form_data(can_be_slow));
if (result) {
break;
}
}
return need_size;
}
default:
UNREACHABLE();
}
break;
}
init(input_, max_post_size_, max_files_);
return 0;
}
// returns Status on wrong request
// returns true if parsing has finished
// returns false if need more data
Result<bool> HttpReader::parse_multipart_form_data(bool can_be_slow) {
while (true) {
LOG(DEBUG) << "Parsing multipart form data in state " << static_cast<int32>(form_data_parse_state_)
<< " with already read length " << form_data_read_length_;
switch (form_data_parse_state_) {
case FormDataParseState::SkipPrologue:
if (find_boundary(content_->clone(), {boundary_.c_str() + 2, boundary_.size() - 2}, form_data_read_length_)) {
size_t to_skip = form_data_read_length_ + (boundary_.size() - 2);
content_->advance(to_skip);
form_data_skipped_length_ += to_skip;
form_data_read_length_ = 0;
form_data_parse_state_ = FormDataParseState::ReadPartHeaders;
continue;
}
content_->advance(form_data_read_length_);
form_data_skipped_length_ += form_data_read_length_;
form_data_read_length_ = 0;
return false;
case FormDataParseState::ReadPartHeaders:
if (find_boundary(content_->clone(), "\r\n\r\n", form_data_read_length_)) {
total_headers_length_ += form_data_read_length_;
if (total_headers_length_ > MAX_TOTAL_HEADERS_LENGTH) {
return Status::Error(431, "Request Header Fields Too Large: total headers size exceeded");
}
if (form_data_read_length_ == 0) {
// there are no headers at all
return Status::Error(400, "Bad Request: headers in multipart/form-data are empty");
}
content_->advance(2); // "\r\n" after boundary
auto headers = content_->cut_head(form_data_read_length_).move_as_buffer_slice();
CHECK(headers.as_slice().size() == form_data_read_length_);
LOG(DEBUG) << "Parse headers in multipart form data: \"" << headers.as_slice() << "\"";
content_->advance(2);
form_data_skipped_length_ += form_data_read_length_ + 4;
form_data_read_length_ = 0;
field_name_ = MutableSlice();
file_field_name_.clear();
field_content_type_ = "application/octet-stream";
file_name_.clear();
has_file_name_ = false;
CHECK(temp_file_.empty());
temp_file_name_.clear();
Parser headers_parser(headers.as_mutable_slice());
while (headers_parser.status().is_ok() && !headers_parser.data().empty()) {
MutableSlice header_name = headers_parser.read_till(':');
headers_parser.skip(':');
char *header_value_begin = headers_parser.ptr();
char *header_value_end;
do {
headers_parser.read_till('\r');
header_value_end = headers_parser.ptr();
headers_parser.skip('\r');
headers_parser.skip('\n');
} while (headers_parser.status().is_ok() &&
(headers_parser.peek_char() == ' ' || headers_parser.peek_char() == '\t'));
MutableSlice header_value(header_value_begin, header_value_end);
header_name = trim(header_name);
header_value = trim(header_value);
to_lower_inplace(header_name);
if (header_name == "content-disposition") {
if (header_value.substr(0, 10) != "form-data;") {
return Status::Error(400, "Bad Request: expected form-data content disposition");
}
header_value.remove_prefix(10);
while (true) {
header_value = trim(header_value);
const auto *key_end =
static_cast<const char *>(std::memchr(header_value.data(), '=', header_value.size()));
if (key_end == nullptr) {
break;
}
size_t key_size = key_end - header_value.data();
auto key = trim(header_value.substr(0, key_size));
header_value.remove_prefix(key_size + 1);
while (!header_value.empty() && is_space(header_value[0])) {
header_value.remove_prefix(1);
}
MutableSlice value;
if (!header_value.empty() && header_value[0] == '"') { // quoted-string
char *value_end = header_value.data() + 1;
const char *pos = value_end;
while (true) {
if (pos == header_value.data() + header_value.size()) {
return Status::Error(400, "Bad Request: unclosed quoted string in Content-Disposition header");
}
char c = *pos++;
if (c == '"') {
break;
}
if (c == '\\') {
if (pos == header_value.data() + header_value.size()) {
return Status::Error(400, "Bad Request: wrong escape sequence in Content-Disposition header");
}
c = *pos++;
}
*value_end++ = c;
}
value = header_value.substr(1, value_end - header_value.data() - 1);
header_value.remove_prefix(pos - header_value.data());
while (!header_value.empty() && is_space(header_value[0])) {
header_value.remove_prefix(1);
}
if (!header_value.empty()) {
if (header_value[0] != ';') {
return Status::Error(400, "Bad Request: expected ';' in Content-Disposition header");
}
header_value.remove_prefix(1);
}
} else { // token
auto value_end =
static_cast<const char *>(std::memchr(header_value.data(), ';', header_value.size()));
if (value_end != nullptr) {
auto value_size = static_cast<size_t>(value_end - header_value.data());
value = trim(header_value.substr(0, value_size));
header_value.remove_prefix(value_size + 1);
} else {
value = trim(header_value);
header_value = MutableSlice();
}
}
value = url_decode_inplace(value, false);
if (key == "name") {
field_name_ = value;
} else if (key == "filename") {
file_name_ = value.str();
has_file_name_ = true;
} else {
// ignore unknown parts of header
}
}
} else if (header_name == "content-type") {
field_content_type_ = header_value.str();
} else {
// ignore unknown header
}
}
if (headers_parser.status().is_error()) {
return Status::Error(400, "Bad Request: can't parse form data headers");
}
if (field_name_.empty()) {
return Status::Error(400, "Bad Request: field name in multipart/form-data not found");
}
if (has_file_name_) {
// file
if (query_->files_.size() == max_files_) {
return Status::Error(413, "Request Entity Too Large: too many files attached");
}
// don't need to save headers for files
file_field_name_ = field_name_.str();
form_data_parse_state_ = FormDataParseState::ReadFile;
} else {
// save headers for query parameters. They contain header names
query_->container_.push_back(std::move(headers));
form_data_parse_state_ = FormDataParseState::ReadPartValue;
}
continue;
}
if (total_headers_length_ + form_data_read_length_ > MAX_TOTAL_HEADERS_LENGTH) {
return Status::Error(431, "Request Header Fields Too Large: total headers size exceeded");
}
return false;
case FormDataParseState::ReadPartValue:
if (find_boundary(content_->clone(), boundary_, form_data_read_length_)) {
if (total_parameters_length_ + form_data_read_length_ > MAX_TOTAL_PARAMETERS_LENGTH) {
return Status::Error(413, "Request Entity Too Large: too many parameters in form data");
}
query_->container_.emplace_back(content_->cut_head(form_data_read_length_).move_as_buffer_slice());
MutableSlice value = query_->container_.back().as_mutable_slice();
content_->advance(boundary_.size());
form_data_skipped_length_ += form_data_read_length_ + boundary_.size();
form_data_read_length_ = 0;
if (begins_with(field_content_type_, "application/x-www-form-urlencoded")) {
// treat value as ordinary parameters
auto result = parse_parameters(value);
if (result.is_error()) {
return std::move(result);
}
} else {
total_parameters_length_ += form_data_read_length_;
LOG(DEBUG) << "Get ordinary parameter in multipart form data: \"" << field_name_ << "\": \"" << value
<< "\"";
query_->args_.emplace_back(field_name_, value);
}
form_data_parse_state_ = FormDataParseState::CheckForLastBoundary;
continue;
}
CHECK(content_->size() < form_data_read_length_ + boundary_.size());
if (total_parameters_length_ + form_data_read_length_ > MAX_TOTAL_PARAMETERS_LENGTH) {
return Status::Error(413, "Request Entity Too Large: too many parameters in form data");
}
return false;
case FormDataParseState::ReadFile: {
if (!can_be_slow) {
return Status::Error("SLOW");
}
if (temp_file_.empty()) {
auto open_status = open_temp_file(file_name_);
if (open_status.is_error()) {
return Status::Error(500, "Internal Server Error: can't create temporary file");
}
}
if (find_boundary(content_->clone(), boundary_, form_data_read_length_)) {
auto file_part = content_->cut_head(form_data_read_length_).move_as_buffer_slice();
content_->advance(boundary_.size());
form_data_skipped_length_ += form_data_read_length_ + boundary_.size();
form_data_read_length_ = 0;
TRY_STATUS(save_file_part(std::move(file_part)));
query_->files_.emplace_back(file_field_name_, file_name_, field_content_type_, file_size_, temp_file_name_);
close_temp_file();
form_data_parse_state_ = FormDataParseState::CheckForLastBoundary;
continue;
}
// TODO optimize?
auto file_part = content_->cut_head(form_data_read_length_).move_as_buffer_slice();
form_data_skipped_length_ += form_data_read_length_;
form_data_read_length_ = 0;
CHECK(content_->size() < boundary_.size());
TRY_STATUS(save_file_part(std::move(file_part)));
return false;
}
case FormDataParseState::CheckForLastBoundary: {
if (content_->size() < 2) {
// need more data
return false;
}
auto range = content_->clone();
char x[2];
range.advance(2, {x, 2});
if (x[0] == '-' && x[1] == '-') {
content_->advance(2);
form_data_skipped_length_ += 2;
form_data_parse_state_ = FormDataParseState::SkipEpilogue;
} else {
form_data_parse_state_ = FormDataParseState::ReadPartHeaders;
}
continue;
}
case FormDataParseState::SkipEpilogue: {
size_t size = content_->size();
LOG(DEBUG) << "Skipping epilogue. Have " << size << " bytes";
content_->advance(size);
form_data_skipped_length_ += size;
// TODO(now): check if form_data_skipped_length is too big
return flow_sink_.is_ready();
}
default:
UNREACHABLE();
}
break;
}
return true;
}
Result<size_t> HttpReader::split_header() {
if (find_boundary(input_->clone(), "\r\n\r\n", headers_read_length_)) {
query_->container_.clear();
auto a = input_->cut_head(headers_read_length_ + 2);
auto b = a.move_as_buffer_slice();
query_->container_.emplace_back(std::move(b));
// query_->container_.emplace_back(input_->cut_head(headers_read_length_ + 2).move_as_buffer_slice());
CHECK(query_->container_.back().size() == headers_read_length_ + 2);
input_->advance(2);
total_headers_length_ = headers_read_length_;
auto status = parse_head(query_->container_.back().as_mutable_slice());
if (status.is_error()) {
return std::move(status);
}
return 0;
}
if (input_->size() > MAX_TOTAL_HEADERS_LENGTH) {
return Status::Error(431, "Request Header Fields Too Large: total headers size exceeded");
}
return input_->size() + 1;
}
void HttpReader::process_header(MutableSlice header_name, MutableSlice header_value) {
header_name = trim(header_name);
header_value = trim(header_value); // TODO need to remove "\r\n" from value
to_lower_inplace(header_name);
LOG(DEBUG) << "Process header [" << header_name << "=>" << header_value << "]";
query_->headers_.emplace_back(header_name, header_value);
if (header_name == "content-length") {
auto content_length = to_integer<uint64>(header_value);
if (content_length > MAX_CONTENT_SIZE) {
content_length = MAX_CONTENT_SIZE;
}
content_length_ = static_cast<int64>(content_length);
} else if (header_name == "connection") {
to_lower_inplace(header_value);
if (header_value == "close") {
query_->keep_alive_ = false;
} else {
query_->keep_alive_ = true;
}
} else if (header_name == "content-type") {
content_type_ = header_value;
content_type_lowercased_ = header_value.str();
to_lower_inplace(content_type_lowercased_);
} else if (header_name == "content-encoding") {
to_lower_inplace(header_value);
content_encoding_ = header_value;
} else if (header_name == "transfer-encoding") {
to_lower_inplace(header_value);
transfer_encoding_ = header_value;
}
}
Status HttpReader::parse_url(MutableSlice url) {
size_t url_path_size = 0;
while (url_path_size < url.size() && url[url_path_size] != '?' && url[url_path_size] != '#') {
url_path_size++;
}
query_->url_path_ = url_decode_inplace({url.data(), url_path_size}, false);
if (url_path_size == url.size() || url[url_path_size] != '?') {
return Status::OK();
}
return parse_parameters(url.substr(url_path_size + 1));
}
Status HttpReader::parse_parameters(MutableSlice parameters) {
total_parameters_length_ += parameters.size();
if (total_parameters_length_ > MAX_TOTAL_PARAMETERS_LENGTH) {
return Status::Error(413, "Request Entity Too Large: too many parameters");
}
LOG(DEBUG) << "Parse parameters: \"" << parameters << "\"";
Parser parser(parameters);
while (!parser.data().empty()) {
auto key_value = parser.read_till_nofail('&');
parser.skip_nofail('&');
Parser kv_parser(key_value);
auto key = url_decode_inplace(kv_parser.read_till_nofail('='), true);
kv_parser.skip_nofail('=');
auto value = url_decode_inplace(kv_parser.data(), true);
query_->args_.emplace_back(key, value);
}
CHECK(parser.status().is_ok());
return Status::OK();
}
Status HttpReader::parse_json_parameters(MutableSlice parameters) {
if (parameters.empty()) {
return Status::OK();
}
total_parameters_length_ += parameters.size();
if (total_parameters_length_ > MAX_TOTAL_PARAMETERS_LENGTH) {
return Status::Error(413, "Request Entity Too Large: too many parameters");
}
LOG(DEBUG) << "Parse JSON parameters: \"" << parameters << "\"";
Parser parser(parameters);
parser.skip_whitespaces();
if (parser.peek_char() == '"') {
auto r_value = json_string_decode(parser);
if (r_value.is_error()) {
return Status::Error(400, PSLICE() << "Bad Request: can't parse string content: " << r_value.error().message());
}
if (!parser.empty()) {
return Status::Error(400, "Bad Request: extra data after string");
}
query_->container_.emplace_back("content");
query_->args_.emplace_back(query_->container_.back().as_mutable_slice(), r_value.move_as_ok());
return Status::OK();
}
parser.skip('{');
if (parser.status().is_error()) {
return Status::Error(400, "Bad Request: JSON object expected");
}
while (true) {
parser.skip_whitespaces();
if (parser.try_skip('}')) {
parser.skip_whitespaces();
if (parser.empty()) {
return Status::OK();
}
return Status::Error(400, "Bad Request: unexpected data after object end");
}
if (parser.empty()) {
return Status::Error(400, "Bad Request: expected parameter name");
}
auto r_key = json_string_decode(parser);
if (r_key.is_error()) {
return Status::Error(400, PSLICE() << "Bad Request: can't parse parameter name: " << r_key.error().message());
}
parser.skip_whitespaces();
if (!parser.try_skip(':')) {
return Status::Error(400, "Bad Request: can't parse object, ':' expected");
}
parser.skip_whitespaces();
auto r_value = [&]() -> Result<MutableSlice> {
if (parser.peek_char() == '"') {
return json_string_decode(parser);
} else {
const int32 DEFAULT_MAX_DEPTH = 100;
auto begin = parser.ptr();
auto result = do_json_skip(parser, DEFAULT_MAX_DEPTH);
if (result.is_ok()) {
return MutableSlice(begin, parser.ptr());
} else {
return result.move_as_error();
}
}
}();
if (r_value.is_error()) {
return Status::Error(400, PSLICE() << "Bad Request: can't parse parameter value: " << r_value.error().message());
}
query_->args_.emplace_back(r_key.move_as_ok(), r_value.move_as_ok());
parser.skip_whitespaces();
if (parser.peek_char() != '}' && !parser.try_skip(',')) {
return Status::Error(400, "Bad Request: expected next field or object end");
}
}
UNREACHABLE();
return Status::OK();
}
Status HttpReader::parse_http_version(Slice version) {
if (version == "HTTP/1.1") {
query_->keep_alive_ = true;
} else if (version == "HTTP/1.0") {
query_->keep_alive_ = false;
} else {
LOG(INFO) << "Unsupported HTTP version: " << version;
return Status::Error(505, "HTTP Version Not Supported");
}
return Status::OK();
}
Status HttpReader::parse_head(MutableSlice head) {
Parser parser(head);
Slice type = parser.read_till(' ');
parser.skip(' ');
// GET POST HTTP/1.1
if (type == "GET") {
query_->type_ = HttpQuery::Type::Get;
} else if (type == "POST") {
query_->type_ = HttpQuery::Type::Post;
} else if (type.size() >= 4 && type.substr(0, 4) == "HTTP") {
TRY_STATUS(parse_http_version(type));
query_->type_ = HttpQuery::Type::Response;
} else {
LOG(INFO) << "Not Implemented " << tag("type", type) << tag("head", head);
return Status::Error(501, "Not Implemented");
}
query_->args_.clear();
if (query_->type_ == HttpQuery::Type::Response) {
query_->code_ = to_integer<int32>(parser.read_till(' '));
parser.skip(' ');
query_->reason_ = parser.read_till('\r');
LOG(DEBUG) << "Receive HTTP response " << query_->code_ << " " << query_->reason_;
} else {
auto url_version = parser.read_till('\r');
auto space_pos = url_version.rfind(' ');
if (space_pos == static_cast<size_t>(-1)) {
return Status::Error(400, "Bad Request: wrong request line");
}
TRY_STATUS(parse_url(url_version.substr(0, space_pos)));
TRY_STATUS(parse_http_version(url_version.substr(space_pos + 1)));
}
parser.skip('\r');
parser.skip('\n');
content_length_ = -1;
content_type_ = Slice("application/octet-stream");
content_type_lowercased_ = content_type_.str();
transfer_encoding_ = Slice();
content_encoding_ = Slice();
query_->headers_.clear();
query_->files_.clear();
query_->content_ = MutableSlice();
while (parser.status().is_ok() && !parser.data().empty()) {
MutableSlice header_name = parser.read_till(':');
parser.skip(':');
char *header_value_begin = parser.ptr();
char *header_value_end;
do {
parser.read_till('\r');
header_value_end = parser.ptr();
parser.skip('\r');
parser.skip('\n');
} while (parser.status().is_ok() && (parser.peek_char() == ' ' || parser.peek_char() == '\t'));
process_header(header_name, {header_value_begin, header_value_end});
}
return parser.status().is_ok() ? Status::OK() : Status::Error(400, "Bad Request");
}
Status HttpReader::open_temp_file(CSlice desired_file_name) {
CHECK(temp_file_.empty());
auto tmp_dir = get_temporary_dir();
if (tmp_dir.empty()) {
return Status::Error("Can't find temporary directory");
}
TRY_RESULT(dir, realpath(tmp_dir, true));
CHECK(!dir.empty());
auto first_try = try_open_temp_file(dir, desired_file_name);
if (first_try.is_ok()) {
return Status::OK();
}
// Creation of new file with desired name has failed. Trying to create unique directory for it
TRY_RESULT(directory, mkdtemp(dir, TEMP_DIRECTORY_PREFIX));
auto second_try = try_open_temp_file(directory, desired_file_name);
if (second_try.is_ok()) {
return Status::OK();
}
auto third_try = try_open_temp_file(directory, "file");
if (third_try.is_ok()) {
return Status::OK();
}
rmdir(directory).ignore();
LOG(WARNING) << "Failed to create temporary file \"" << desired_file_name << "\": " << second_try.error();
return second_try.move_as_error();
}
Status HttpReader::try_open_temp_file(Slice directory_name, CSlice desired_file_name) {
CHECK(temp_file_.empty());
CHECK(!directory_name.empty());
string file_name = clean_filename(desired_file_name);
if (file_name.empty()) {
file_name = "file";
}
temp_file_name_.clear();
temp_file_name_.reserve(directory_name.size() + 1 + file_name.size());
temp_file_name_.append(directory_name.data(), directory_name.size());
if (temp_file_name_.back() != TD_DIR_SLASH) {
temp_file_name_ += TD_DIR_SLASH;
}
temp_file_name_.append(file_name.data(), file_name.size());
TRY_RESULT(opened_file, FileFd::open(temp_file_name_, FileFd::Write | FileFd::CreateNew, 0640));
file_size_ = 0;
temp_file_ = std::move(opened_file);
LOG(DEBUG) << "Created temporary file " << temp_file_name_;
return Status::OK();
}
Status HttpReader::save_file_part(BufferSlice &&file_part) {
file_size_ += narrow_cast<int64>(file_part.size());
if (file_size_ > MAX_FILE_SIZE) {
clean_temporary_file();
return Status::Error(
413, PSLICE() << "Request Entity Too Large: file of size " << file_size_ << " is too big to be uploaded");
}
LOG(DEBUG) << "Save file part of size " << file_part.size() << " to file " << temp_file_name_;
auto result_written = temp_file_.write(file_part.as_slice());
if (result_written.is_error() || result_written.ok() != file_part.size()) {
clean_temporary_file();
return Status::Error(500, "Internal Server Error: can't upload the file");
}
return Status::OK();
}
void HttpReader::clean_temporary_file() {
string file_name = temp_file_name_;
close_temp_file();
delete_temp_file(file_name);
}
void HttpReader::close_temp_file() {
LOG(DEBUG) << "Close temporary file " << temp_file_name_;
CHECK(!temp_file_.empty());
temp_file_.close();
CHECK(temp_file_.empty());
temp_file_name_.clear();
}
void HttpReader::delete_temp_file(CSlice file_name) {
CHECK(!file_name.empty());
LOG(DEBUG) << "Unlink temporary file " << file_name;
unlink(file_name).ignore();
PathView path_view(file_name);
Slice parent = path_view.parent_dir();
const size_t prefix_length = std::strlen(TEMP_DIRECTORY_PREFIX);
if (parent.size() >= prefix_length + 7 &&
parent.substr(parent.size() - prefix_length - 7, prefix_length) == TEMP_DIRECTORY_PREFIX) {
LOG(DEBUG) << "Unlink temporary directory " << parent;
rmdir(PSLICE() << Slice(parent.data(), parent.size() - 1)).ignore();
}
}
} // namespace td

View File

@@ -0,0 +1,116 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once
#include "td/net/HttpChunkedByteFlow.h"
#include "td/net/HttpContentLengthByteFlow.h"
#include "td/net/HttpQuery.h"
#include "td/utils/buffer.h"
#include "td/utils/ByteFlow.h"
#include "td/utils/common.h"
#include "td/utils/GzipByteFlow.h"
#include "td/utils/port/FileFd.h"
#include "td/utils/Slice.h"
#include "td/utils/Status.h"
#include "td/utils/StringBuilder.h"
#include <limits>
namespace td {
class HttpReader {
public:
void init(ChainBufferReader *input, size_t max_post_size = std::numeric_limits<size_t>::max(),
size_t max_files = 100);
Result<size_t> read_next(HttpQuery *query, bool can_be_slow = true) TD_WARN_UNUSED_RESULT; // TODO move query to init
HttpReader() = default;
HttpReader(const HttpReader &) = delete;
HttpReader &operator=(const HttpReader &) = delete;
HttpReader(HttpReader &&) = delete;
HttpReader &operator=(HttpReader &&) = delete;
~HttpReader() {
if (!temp_file_.empty()) {
clean_temporary_file();
}
}
static void delete_temp_file(CSlice file_name);
private:
size_t max_post_size_ = 0;
size_t max_files_ = 0;
enum class State { ReadHeaders, ReadContent, ReadContentToFile, ReadArgs, ReadMultipartFormData };
State state_ = State::ReadHeaders;
size_t headers_read_length_ = 0;
int64 content_length_ = -1;
ChainBufferReader *input_ = nullptr;
ByteFlowSource flow_source_;
HttpChunkedByteFlow chunked_flow_;
GzipByteFlow gzip_flow_;
HttpContentLengthByteFlow content_length_flow_;
ByteFlowSink flow_sink_;
ChainBufferReader *content_ = nullptr;
HttpQuery *query_ = nullptr;
Slice transfer_encoding_;
Slice content_encoding_;
Slice content_type_;
string content_type_lowercased_;
size_t total_parameters_length_ = 0;
size_t total_headers_length_ = 0;
string boundary_;
size_t form_data_read_length_ = 0;
size_t form_data_skipped_length_ = 0;
enum class FormDataParseState : int32 {
SkipPrologue,
ReadPartHeaders,
ReadPartValue,
ReadFile,
CheckForLastBoundary,
SkipEpilogue
};
FormDataParseState form_data_parse_state_ = FormDataParseState::SkipPrologue;
MutableSlice field_name_;
string file_field_name_;
string field_content_type_;
string file_name_;
bool has_file_name_ = false;
FileFd temp_file_;
string temp_file_name_;
int64 file_size_ = 0;
Result<size_t> do_read_next(bool can_be_slow);
Result<size_t> split_header() TD_WARN_UNUSED_RESULT;
void process_header(MutableSlice header_name, MutableSlice header_value);
Result<bool> parse_multipart_form_data(bool can_be_slow) TD_WARN_UNUSED_RESULT;
Status parse_url(MutableSlice url) TD_WARN_UNUSED_RESULT;
Status parse_parameters(MutableSlice parameters) TD_WARN_UNUSED_RESULT;
Status parse_json_parameters(MutableSlice parameters) TD_WARN_UNUSED_RESULT;
Status parse_http_version(Slice version) TD_WARN_UNUSED_RESULT;
Status parse_head(MutableSlice head) TD_WARN_UNUSED_RESULT;
Status open_temp_file(CSlice desired_file_name) TD_WARN_UNUSED_RESULT;
Status try_open_temp_file(Slice directory_name, CSlice desired_file_name) TD_WARN_UNUSED_RESULT;
Status save_file_part(BufferSlice &&file_part) TD_WARN_UNUSED_RESULT;
void close_temp_file();
void clean_temporary_file();
static constexpr size_t MAX_CONTENT_SIZE = std::numeric_limits<uint32>::max(); // Some reasonable limit
static constexpr size_t MAX_TOTAL_PARAMETERS_LENGTH = 1 << 20; // Some reasonable limit
static constexpr size_t MAX_TOTAL_HEADERS_LENGTH = 1 << 18; // Some reasonable limit
static constexpr size_t MAX_BOUNDARY_LENGTH = 70; // As defined by RFC1341
static constexpr int64 MAX_FILE_SIZE = static_cast<int64>(4000) << 20; // Telegram server file size limit
static constexpr const char TEMP_DIRECTORY_PREFIX[] = "tdlib-server-tmp";
};
} // namespace td

144
td/tdnet/td/net/NetStats.h Normal file
View File

@@ -0,0 +1,144 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once
#include "td/actor/SchedulerLocalStorage.h"
#include "td/utils/common.h"
#include "td/utils/format.h"
#include "td/utils/StringBuilder.h"
#include "td/utils/Time.h"
#include <atomic>
#include <memory>
namespace td {
class NetStatsCallback {
public:
virtual void on_read(uint64 bytes) = 0;
virtual void on_write(uint64 bytes) = 0;
NetStatsCallback() = default;
NetStatsCallback(const NetStatsCallback &) = delete;
NetStatsCallback &operator=(const NetStatsCallback &) = delete;
virtual ~NetStatsCallback() = default;
};
struct NetStatsData {
uint64 read_size = 0;
uint64 write_size = 0;
uint64 count = 0;
double duration = 0;
};
inline NetStatsData operator+(const NetStatsData &a, const NetStatsData &b) {
NetStatsData res;
res.read_size = a.read_size + b.read_size;
res.write_size = a.write_size + b.write_size;
res.count = a.count + b.count;
res.duration = a.duration + b.duration;
return res;
}
inline NetStatsData operator-(const NetStatsData &a, const NetStatsData &b) {
NetStatsData res;
CHECK(a.read_size >= b.read_size);
res.read_size = a.read_size - b.read_size;
CHECK(a.write_size >= b.write_size);
res.write_size = a.write_size - b.write_size;
CHECK(a.count >= b.count);
res.count = a.count - b.count;
CHECK(a.duration >= b.duration);
res.duration = a.duration - b.duration;
return res;
}
inline StringBuilder &operator<<(StringBuilder &sb, const NetStatsData &data) {
return sb << tag("Rx size", format::as_size(data.read_size)) << tag("Tx size", format::as_size(data.write_size))
<< tag("count", data.count) << tag("duration", format::as_time(data.duration));
}
class NetStats {
public:
class Callback {
public:
virtual void on_stats_updated() = 0;
Callback() = default;
Callback(const Callback &) = delete;
Callback &operator=(const Callback &) = delete;
virtual ~Callback() = default;
};
std::shared_ptr<NetStatsCallback> get_callback() const {
return impl_;
}
NetStatsData get_stats() const {
return impl_->get_stats();
}
// do it before get_callback
void set_callback(unique_ptr<Callback> callback) {
impl_->set_callback(std::move(callback));
}
private:
class Impl final : public NetStatsCallback {
public:
NetStatsData get_stats() const {
NetStatsData res;
local_net_stats_.for_each([&](auto &stats) {
res.read_size += stats.read_size.load(std::memory_order_relaxed);
res.write_size += stats.write_size.load(std::memory_order_relaxed);
});
return res;
}
void set_callback(unique_ptr<Callback> callback) {
callback_ = std::move(callback);
}
private:
struct LocalNetStats {
double last_update = 0;
uint64 unsync_size = 0;
std::atomic<uint64> read_size{0};
std::atomic<uint64> write_size{0};
};
SchedulerLocalStorage<LocalNetStats> local_net_stats_;
unique_ptr<Callback> callback_;
void on_read(uint64 size) final {
auto &stats = local_net_stats_.get();
stats.read_size.fetch_add(size, std::memory_order_relaxed);
on_change(stats, size);
}
void on_write(uint64 size) final {
auto &stats = local_net_stats_.get();
stats.write_size.fetch_add(size, std::memory_order_relaxed);
on_change(stats, size);
}
void on_change(LocalNetStats &stats, uint64 size) {
stats.unsync_size += size;
auto now = Time::now();
if (stats.unsync_size > 10000 || now - stats.last_update > 300) {
stats.unsync_size = 0;
stats.last_update = now;
callback_->on_stats_updated();
}
}
};
std::shared_ptr<Impl> impl_{std::make_shared<Impl>()};
};
} // namespace td

190
td/tdnet/td/net/Socks5.cpp Normal file
View File

@@ -0,0 +1,190 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#include "td/net/Socks5.h"
#include "td/utils/common.h"
#include "td/utils/logging.h"
#include "td/utils/misc.h"
#include "td/utils/Slice.h"
#include "td/utils/SliceBuilder.h"
namespace td {
void Socks5::send_greeting() {
VLOG(proxy) << "Send greeting to proxy";
CHECK(state_ == State::SendGreeting);
state_ = State::WaitGreetingResponse;
string greeting;
greeting += '\x05';
bool use_username = !username_.empty();
char authentication_count = use_username ? '\x02' : '\x01';
greeting += authentication_count;
greeting += '\0';
if (use_username) {
greeting += '\x02';
}
fd_.output_buffer().append(greeting);
}
Status Socks5::wait_greeting_response() {
auto &buf = fd_.input_buffer();
VLOG(proxy) << "Receive greeting response of size " << buf.size();
if (buf.size() < 2) {
return Status::OK();
}
auto buffer_slice = buf.read_as_buffer_slice(2);
auto slice = buffer_slice.as_slice();
if (slice[0] != '\x05') {
return Status::Error(PSLICE() << "Unsupported socks protocol version " << static_cast<int>(slice[0]));
}
auto authentication_method = slice[1];
if (authentication_method == '\0') {
send_ip_address();
return Status::OK();
}
if (authentication_method == '\x02') {
return send_username_password();
}
return Status::Error("Unsupported authentication mode");
}
Status Socks5::send_username_password() {
VLOG(proxy) << "Send username and password";
if (username_.size() >= 128) {
return Status::Error("Username is too long");
}
if (password_.size() >= 128) {
return Status::Error("Password is too long");
}
string request;
request += '\x01';
request += narrow_cast<char>(username_.size());
request += username_;
request += narrow_cast<char>(password_.size());
request += password_;
fd_.output_buffer().append(request);
state_ = State::WaitPasswordResponse;
return Status::OK();
}
Status Socks5::wait_password_response() {
auto &buf = fd_.input_buffer();
VLOG(proxy) << "Receive password response of size " << buf.size();
if (buf.size() < 2) {
return Status::OK();
}
auto buffer_slice = buf.read_as_buffer_slice(2);
auto slice = buffer_slice.as_slice();
if (slice[0] != '\x01') {
return Status::Error(PSLICE() << "Unsupported socks subnegotiation protocol version "
<< static_cast<int>(slice[0]));
}
if (slice[1] != '\x00') {
return Status::Error("Wrong username or password");
}
send_ip_address();
return Status::OK();
}
void Socks5::send_ip_address() {
VLOG(proxy) << "Send IP address";
callback_->on_connected();
string request;
request += '\x05';
request += '\x01';
request += '\x00';
if (ip_address_.is_ipv4()) {
request += '\x01';
auto ipv4 = ntohl(ip_address_.get_ipv4());
request += static_cast<char>(ipv4 & 255);
request += static_cast<char>((ipv4 >> 8) & 255);
request += static_cast<char>((ipv4 >> 16) & 255);
request += static_cast<char>((ipv4 >> 24) & 255);
} else {
request += '\x04';
request += ip_address_.get_ipv6();
}
auto port = ip_address_.get_port();
request += static_cast<char>((port >> 8) & 255);
request += static_cast<char>(port & 255);
fd_.output_buffer().append(request);
state_ = State::WaitIpAddressResponse;
}
Status Socks5::wait_ip_address_response() {
CHECK(state_ == State::WaitIpAddressResponse);
auto it = fd_.input_buffer().clone();
VLOG(proxy) << "Receive IP address response of size " << it.size();
if (it.size() < 4) {
return Status::OK();
}
char c;
MutableSlice c_slice(&c, 1);
it.advance(1, c_slice);
if (c != '\x05') {
return Status::Error("Invalid response");
}
it.advance(1, c_slice);
if (c != '\0') {
return Status::Error(PSLICE() << "Receive error code " << static_cast<int32>(c) << " from server");
}
it.advance(1, c_slice);
if (c != '\0') {
return Status::Error("Byte must be zero");
}
it.advance(1, c_slice);
size_t total_size = 6;
if (c == '\x01') {
if (it.size() < 4) {
return Status::OK();
}
it.advance(4);
total_size += 4;
} else if (c == '\x04') {
if (it.size() < 16) {
return Status::OK();
}
it.advance(16);
total_size += 16;
} else {
return Status::Error("Invalid response");
}
if (it.size() < 2) {
return Status::OK();
}
it.advance(2);
fd_.input_buffer().advance(total_size);
stop();
return Status::OK();
}
Status Socks5::loop_impl() {
switch (state_) {
case State::SendGreeting:
send_greeting();
break;
case State::WaitGreetingResponse:
TRY_STATUS(wait_greeting_response());
break;
case State::WaitPasswordResponse:
TRY_STATUS(wait_password_response());
break;
case State::WaitIpAddressResponse:
TRY_STATUS(wait_ip_address_response());
break;
default:
UNREACHABLE();
}
return Status::OK();
}
} // namespace td

39
td/tdnet/td/net/Socks5.h Normal file
View File

@@ -0,0 +1,39 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once
#include "td/net/TransparentProxy.h"
#include "td/utils/Status.h"
namespace td {
class Socks5 final : public TransparentProxy {
public:
using TransparentProxy::TransparentProxy;
private:
enum class State {
SendGreeting,
WaitGreetingResponse,
WaitPasswordResponse,
WaitIpAddressResponse
} state_ = State::SendGreeting;
void send_greeting();
Status wait_greeting_response();
Status send_username_password();
Status wait_password_response();
void send_ip_address();
Status wait_ip_address_response();
Status loop_impl() final;
};
} // namespace td

370
td/tdnet/td/net/SslCtx.cpp Normal file
View File

@@ -0,0 +1,370 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#include "td/net/SslCtx.h"
#include "td/utils/common.h"
#include "td/utils/crypto.h"
#include "td/utils/FlatHashMap.h"
#include "td/utils/logging.h"
#include "td/utils/misc.h"
#include "td/utils/port/path.h"
#include "td/utils/port/wstring_convert.h"
#include "td/utils/ScopeGuard.h"
#include "td/utils/SliceBuilder.h"
#include "td/utils/Time.h"
#if !TD_EMSCRIPTEN
#include <openssl/err.h>
#include <openssl/ssl.h>
#include <openssl/x509.h>
#include <openssl/x509_vfy.h>
#include <cstring>
#include <memory>
#include <mutex>
#if TD_PORT_WINDOWS
#include <wincrypt.h>
#endif
namespace td {
namespace detail {
namespace {
int verify_callback(int preverify_ok, X509_STORE_CTX *ctx) {
if (!preverify_ok) {
char buf[256];
X509_NAME_oneline(X509_get_subject_name(X509_STORE_CTX_get_current_cert(ctx)), buf, 256);
int err = X509_STORE_CTX_get_error(ctx);
auto warning = PSTRING() << "verify error:num=" << err << ":" << X509_verify_cert_error_string(err)
<< ":depth=" << X509_STORE_CTX_get_error_depth(ctx) << ":" << Slice(buf, std::strlen(buf));
double now = Time::now();
static std::mutex warning_mutex;
{
std::lock_guard<std::mutex> lock(warning_mutex);
static FlatHashMap<string, double> next_warning_time;
double &next = next_warning_time[warning];
if (next <= now) {
next = now + 300; // one warning per 5 minutes
LOG(WARNING) << warning;
}
}
}
return preverify_ok;
}
X509_STORE *load_system_certificate_store() {
int32 cert_count = 0;
int32 file_count = 0;
LOG(DEBUG) << "Begin to load system certificate store";
SCOPE_EXIT {
LOG(DEBUG) << "End to load " << cert_count << " certificates from " << file_count << " files from system store";
if (ERR_peek_error() != 0) {
auto error = create_openssl_error(-22, "Have unprocessed errors");
LOG(INFO) << error;
}
};
#if TD_PORT_WINDOWS
auto flags = CERT_STORE_OPEN_EXISTING_FLAG | CERT_STORE_READONLY_FLAG | CERT_SYSTEM_STORE_CURRENT_USER;
HCERTSTORE system_store =
CertOpenStore(CERT_STORE_PROV_SYSTEM_W, X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, HCRYPTPROV_LEGACY(), flags,
static_cast<const void *>(to_wstring("ROOT").ok().c_str()));
if (!system_store) {
return nullptr;
}
X509_STORE *store = X509_STORE_new();
if (store == nullptr) {
return nullptr;
}
for (PCCERT_CONTEXT cert_context = CertEnumCertificatesInStore(system_store, nullptr); cert_context != nullptr;
cert_context = CertEnumCertificatesInStore(system_store, cert_context)) {
const unsigned char *in = cert_context->pbCertEncoded;
X509 *x509 = d2i_X509(nullptr, &in, static_cast<long>(cert_context->cbCertEncoded));
if (x509 != nullptr) {
if (X509_STORE_add_cert(store, x509) != 1) {
auto error_code = ERR_peek_error();
auto error = create_openssl_error(-20, "Failed to add certificate");
if (ERR_GET_REASON(error_code) != X509_R_CERT_ALREADY_IN_HASH_TABLE) {
LOG(ERROR) << error;
} else {
LOG(INFO) << error;
}
} else {
cert_count++;
}
X509_free(x509);
} else {
LOG(ERROR) << create_openssl_error(-21, "Failed to load X509 certificate");
}
}
CertCloseStore(system_store, 0);
#else
X509_STORE *store = X509_STORE_new();
if (store == nullptr) {
return nullptr;
}
auto add_file = [&](CSlice path) {
if (X509_STORE_load_locations(store, path.c_str(), nullptr) != 1) {
auto error = create_openssl_error(-20, "Failed to add certificate");
LOG(INFO) << path << ": " << error;
} else {
file_count++;
}
};
string default_cert_dir = X509_get_default_cert_dir();
for (auto cert_dir : full_split(default_cert_dir, ':')) {
walk_path(cert_dir, [&](CSlice path, WalkPath::Type type) {
if (type != WalkPath::Type::RegularFile && type != WalkPath::Type::Symlink) {
return type == WalkPath::Type::EnterDir && path != cert_dir ? WalkPath::Action::SkipDir
: WalkPath::Action::Continue;
}
add_file(path);
return WalkPath::Action::Continue;
}).ignore();
}
string default_cert_path = X509_get_default_cert_file();
if (!default_cert_path.empty()) {
add_file(default_cert_path);
}
#if OPENSSL_VERSION_NUMBER >= 0x10100000L
auto objects = X509_STORE_get0_objects(store);
cert_count = objects == nullptr ? 0 : sk_X509_OBJECT_num(objects);
#else
cert_count = -1;
#endif
#endif
return store;
}
using SslCtxPtr = std::shared_ptr<SSL_CTX>;
Result<SslCtxPtr> do_create_ssl_ctx(CSlice cert_file, SslCtx::VerifyPeer verify_peer) {
auto ssl_method =
#if OPENSSL_VERSION_NUMBER >= 0x10100000L
TLS_client_method();
#else
SSLv23_client_method();
#endif
if (ssl_method == nullptr) {
return create_openssl_error(-6, "Failed to create an SSL client method");
}
auto ssl_ctx = SSL_CTX_new(ssl_method);
if (!ssl_ctx) {
return create_openssl_error(-7, "Failed to create an SSL context");
}
auto ssl_ctx_ptr = SslCtxPtr(ssl_ctx, SSL_CTX_free);
long options = 0;
#ifdef SSL_OP_NO_SSLv2
options |= SSL_OP_NO_SSLv2;
#endif
#ifdef SSL_OP_NO_SSLv3
options |= SSL_OP_NO_SSLv3;
#endif
SSL_CTX_set_options(ssl_ctx, options);
#if OPENSSL_VERSION_NUMBER >= 0x10100000L
SSL_CTX_set_min_proto_version(ssl_ctx, TLS1_VERSION);
#endif
SSL_CTX_set_mode(ssl_ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | SSL_MODE_ENABLE_PARTIAL_WRITE);
if (cert_file.empty()) {
auto *store = load_system_certificate_store();
if (store == nullptr) {
auto error = create_openssl_error(-8, "Failed to load system certificate store");
if (verify_peer == SslCtx::VerifyPeer::On) {
return std::move(error);
} else {
LOG(ERROR) << error;
}
} else {
SSL_CTX_set_cert_store(ssl_ctx, store);
}
} else {
if (SSL_CTX_load_verify_locations(ssl_ctx, cert_file.c_str(), nullptr) == 0) {
return create_openssl_error(-8, "Failed to set custom certificate file");
}
}
if (verify_peer == SslCtx::VerifyPeer::On) {
SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, verify_callback);
constexpr int DEFAULT_VERIFY_DEPTH = 10;
SSL_CTX_set_verify_depth(ssl_ctx, DEFAULT_VERIFY_DEPTH);
} else {
SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_NONE, nullptr);
}
string cipher_list;
if (SSL_CTX_set_cipher_list(ssl_ctx, cipher_list.empty() ? "DEFAULT" : cipher_list.c_str()) == 0) {
return create_openssl_error(-9, PSLICE() << "Failed to set cipher list \"" << cipher_list << '"');
}
return std::move(ssl_ctx_ptr);
}
Result<SslCtxPtr> get_default_ssl_ctx() {
static auto ctx = do_create_ssl_ctx(CSlice(), SslCtx::VerifyPeer::On);
if (ctx.is_error()) {
return ctx.error().clone();
}
return ctx.ok();
}
Result<SslCtxPtr> get_default_unverified_ssl_ctx() {
static auto ctx = do_create_ssl_ctx(CSlice(), SslCtx::VerifyPeer::Off);
if (ctx.is_error()) {
return ctx.error().clone();
}
return ctx.ok();
}
} // namespace
class SslCtxImpl {
public:
Status init(CSlice cert_file, SslCtx::VerifyPeer verify_peer) {
SslCtx::init_openssl();
clear_openssl_errors("Before SslCtx::init");
if (cert_file.empty()) {
if (verify_peer == SslCtx::VerifyPeer::On) {
TRY_RESULT_ASSIGN(ssl_ctx_ptr_, get_default_ssl_ctx());
} else {
TRY_RESULT_ASSIGN(ssl_ctx_ptr_, get_default_unverified_ssl_ctx());
}
return Status::OK();
}
auto start_time = Time::now();
auto r_ssl_ctx_ptr = do_create_ssl_ctx(cert_file, verify_peer);
auto elapsed_time = Time::now() - start_time;
if (elapsed_time >= 0.1) {
LOG(WARNING) << "SSL context creation took " << elapsed_time << " seconds";
}
if (r_ssl_ctx_ptr.is_error()) {
return r_ssl_ctx_ptr.move_as_error();
}
ssl_ctx_ptr_ = r_ssl_ctx_ptr.move_as_ok();
return Status::OK();
}
void *get_openssl_ctx() const {
return static_cast<void *>(ssl_ctx_ptr_.get());
}
private:
SslCtxPtr ssl_ctx_ptr_;
};
} // namespace detail
SslCtx::SslCtx() = default;
SslCtx::SslCtx(const SslCtx &other) {
if (other.impl_) {
impl_ = make_unique<detail::SslCtxImpl>(*other.impl_);
}
}
SslCtx &SslCtx::operator=(const SslCtx &other) {
if (other.impl_) {
impl_ = make_unique<detail::SslCtxImpl>(*other.impl_);
} else {
impl_ = nullptr;
}
return *this;
}
SslCtx::SslCtx(SslCtx &&) noexcept = default;
SslCtx &SslCtx::operator=(SslCtx &&) noexcept = default;
SslCtx::~SslCtx() = default;
void SslCtx::init_openssl() {
static bool is_inited = [] {
#if OPENSSL_VERSION_NUMBER >= 0x10100000L
return OPENSSL_init_ssl(0, nullptr) != 0;
#else
OpenSSL_add_all_algorithms();
SSL_load_error_strings();
return OpenSSL_add_ssl_algorithms() != 0;
#endif
}();
CHECK(is_inited);
}
Result<SslCtx> SslCtx::create(CSlice cert_file, VerifyPeer verify_peer) {
auto impl = make_unique<detail::SslCtxImpl>();
TRY_STATUS(impl->init(cert_file, verify_peer));
return SslCtx(std::move(impl));
}
void *SslCtx::get_openssl_ctx() const {
return impl_ == nullptr ? nullptr : impl_->get_openssl_ctx();
}
SslCtx::SslCtx(unique_ptr<detail::SslCtxImpl> impl) : impl_(std::move(impl)) {
}
} // namespace td
#else
namespace td {
namespace detail {
class SslCtxImpl {};
} // namespace detail
SslCtx::SslCtx() = default;
SslCtx::SslCtx(const SslCtx &other) {
UNREACHABLE();
}
SslCtx &SslCtx::operator=(const SslCtx &other) {
UNREACHABLE();
return *this;
}
SslCtx::SslCtx(SslCtx &&) noexcept = default;
SslCtx &SslCtx::operator=(SslCtx &&) noexcept = default;
SslCtx::~SslCtx() = default;
void SslCtx::init_openssl() {
}
Result<SslCtx> SslCtx::create(CSlice cert_file, VerifyPeer verify_peer) {
return Status::Error("Not supported in Emscripten");
}
void *SslCtx::get_openssl_ctx() const {
return nullptr;
}
SslCtx::SslCtx(unique_ptr<detail::SslCtxImpl> impl) : impl_(std::move(impl)) {
}
} // namespace td
#endif

45
td/tdnet/td/net/SslCtx.h Normal file
View File

@@ -0,0 +1,45 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once
#include "td/utils/Slice.h"
#include "td/utils/Status.h"
namespace td {
namespace detail {
class SslCtxImpl;
} // namespace detail
class SslCtx {
public:
SslCtx();
SslCtx(const SslCtx &other);
SslCtx &operator=(const SslCtx &other);
SslCtx(SslCtx &&) noexcept;
SslCtx &operator=(SslCtx &&) noexcept;
~SslCtx();
static void init_openssl();
enum class VerifyPeer { On, Off };
static Result<SslCtx> create(CSlice cert_file, VerifyPeer verify_peer);
void *get_openssl_ctx() const;
explicit operator bool() const noexcept {
return static_cast<bool>(impl_);
}
private:
unique_ptr<detail::SslCtxImpl> impl_;
explicit SslCtx(unique_ptr<detail::SslCtxImpl> impl);
};
} // namespace td

View File

@@ -0,0 +1,423 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#include "td/net/SslStream.h"
#if !TD_EMSCRIPTEN
#include "td/utils/common.h"
#include "td/utils/crypto.h"
#include "td/utils/logging.h"
#include "td/utils/misc.h"
#include "td/utils/port/IPAddress.h"
#include "td/utils/Status.h"
#include "td/utils/Time.h"
#include <openssl/bio.h>
#include <openssl/err.h>
#include <openssl/ssl.h>
#include <openssl/x509.h>
#include <cstring>
#include <memory>
namespace td {
namespace detail {
namespace {
#if OPENSSL_VERSION_NUMBER < 0x10100000L
void *BIO_get_data(BIO *b) {
return b->ptr;
}
void BIO_set_data(BIO *b, void *ptr) {
b->ptr = ptr;
}
void BIO_set_init(BIO *b, int init) {
b->init = init;
}
int BIO_get_new_index() {
return 0;
}
BIO_METHOD *BIO_meth_new(int type, const char *name) {
auto res = new BIO_METHOD();
std::memset(res, 0, sizeof(*res));
return res;
}
int BIO_meth_set_write(BIO_METHOD *biom, int (*bwrite)(BIO *, const char *, int)) {
biom->bwrite = bwrite;
return 1;
}
int BIO_meth_set_read(BIO_METHOD *biom, int (*bread)(BIO *, char *, int)) {
biom->bread = bread;
return 1;
}
int BIO_meth_set_ctrl(BIO_METHOD *biom, long (*ctrl)(BIO *, int, long, void *)) {
biom->ctrl = ctrl;
return 1;
}
int BIO_meth_set_create(BIO_METHOD *biom, int (*create)(BIO *)) {
biom->create = create;
return 1;
}
int BIO_meth_set_destroy(BIO_METHOD *biom, int (*destroy)(BIO *)) {
biom->destroy = destroy;
return 1;
}
#endif
int strm_create(BIO *b) {
BIO_set_init(b, 1);
return 1;
}
int strm_destroy(BIO *b) {
return 1;
}
int strm_read(BIO *b, char *buf, int len);
int strm_write(BIO *b, const char *buf, int len);
long strm_ctrl(BIO *b, int cmd, long num, void *ptr) {
switch (cmd) {
case BIO_CTRL_FLUSH:
return 1;
case BIO_CTRL_PUSH:
case BIO_CTRL_POP:
return 0;
#if defined(BIO_CTRL_GET_KTLS_SEND)
case BIO_CTRL_GET_KTLS_SEND:
return 0;
#endif
#if defined(BIO_CTRL_GET_KTLS_RECV)
case BIO_CTRL_GET_KTLS_RECV:
return 0;
#endif
default:
LOG(FATAL) << b << " " << cmd << " " << num << " " << ptr;
}
return 1;
}
BIO_METHOD *BIO_s_sslstream() {
static BIO_METHOD *result = [] {
BIO_METHOD *res = BIO_meth_new(BIO_get_new_index(), "td::SslStream helper bio");
BIO_meth_set_write(res, strm_write);
BIO_meth_set_read(res, strm_read);
BIO_meth_set_create(res, strm_create);
BIO_meth_set_destroy(res, strm_destroy);
BIO_meth_set_ctrl(res, strm_ctrl);
return res;
}();
return result;
}
struct SslHandleDeleter {
void operator()(SSL *ssl_handle) {
auto start_time = Time::now();
if (SSL_is_init_finished(ssl_handle)) {
clear_openssl_errors("Before SSL_shutdown");
SSL_set_quiet_shutdown(ssl_handle, 1);
SSL_shutdown(ssl_handle);
clear_openssl_errors("After SSL_shutdown");
}
SSL_free(ssl_handle);
auto elapsed_time = Time::now() - start_time;
if (elapsed_time >= 0.1) {
LOG(WARNING) << "SSL_free took " << elapsed_time << " seconds";
}
}
};
using SslHandle = std::unique_ptr<SSL, SslHandleDeleter>;
} // namespace
class SslStreamImpl {
public:
Status init(CSlice host, SslCtx ssl_ctx, bool check_ip_address_as_host) {
if (!ssl_ctx) {
return Status::Error("Invalid SSL context provided");
}
clear_openssl_errors("Before SslFd::init");
auto ssl_handle = SslHandle(SSL_new(static_cast<SSL_CTX *>(ssl_ctx.get_openssl_ctx())));
if (!ssl_handle) {
return create_openssl_error(-13, "Failed to create an SSL handle");
}
auto r_ip_address = IPAddress::get_ip_address(host);
#if OPENSSL_VERSION_NUMBER >= 0x10002000L
X509_VERIFY_PARAM *param = SSL_get0_param(ssl_handle.get());
X509_VERIFY_PARAM_set_hostflags(param, 0);
if (r_ip_address.is_ok() && !check_ip_address_as_host) {
LOG(DEBUG) << "Set verification IP address to " << r_ip_address.ok().get_ip_str();
X509_VERIFY_PARAM_set1_ip_asc(param, r_ip_address.ok().get_ip_str().c_str());
} else {
LOG(DEBUG) << "Set verification host to " << host;
X509_VERIFY_PARAM_set1_host(param, host.c_str(), 0);
}
#else
#warning DANGEROUS! HTTPS HOST WILL NOT BE CHECKED. INSTALL OPENSSL >= 1.0.2 OR IMPLEMENT HTTPS HOST CHECK MANUALLY
#endif
auto *bio = BIO_new(BIO_s_sslstream());
BIO_set_data(bio, static_cast<void *>(this));
SSL_set_bio(ssl_handle.get(), bio, bio);
#if OPENSSL_VERSION_NUMBER >= 0x0090806fL && !defined(OPENSSL_NO_TLSEXT)
if (r_ip_address.is_error()) { // IP address must not be sent as SNI
LOG(DEBUG) << "Set SNI host name to " << host;
auto host_str = host.str();
SSL_set_tlsext_host_name(ssl_handle.get(), MutableCSlice(host_str).begin());
}
#endif
SSL_set_connect_state(ssl_handle.get());
ssl_handle_ = std::move(ssl_handle);
return Status::OK();
}
ByteFlowInterface &read_byte_flow() {
return read_flow_;
}
ByteFlowInterface &write_byte_flow() {
return write_flow_;
}
size_t flow_read(MutableSlice slice) {
return read_flow_.read(slice);
}
size_t flow_write(Slice slice) {
return write_flow_.write(slice);
}
private:
SslHandle ssl_handle_;
friend class SslReadByteFlow;
friend class SslWriteByteFlow;
Result<size_t> write(Slice slice) {
clear_openssl_errors("Before SslFd::write");
auto start_time = Time::now();
auto size = SSL_write(ssl_handle_.get(), slice.data(), static_cast<int>(slice.size()));
auto elapsed_time = Time::now() - start_time;
if (elapsed_time >= 0.1) {
LOG(WARNING) << "SSL_write of size " << slice.size() << " took " << elapsed_time << " seconds and returned "
<< size << ' ' << SSL_get_error(ssl_handle_.get(), size);
}
if (size <= 0) {
return process_ssl_error(size);
}
return size;
}
Result<size_t> read(MutableSlice slice) {
clear_openssl_errors("Before SslFd::read");
auto start_time = Time::now();
auto size = SSL_read(ssl_handle_.get(), slice.data(), static_cast<int>(slice.size()));
auto elapsed_time = Time::now() - start_time;
if (elapsed_time >= 0.1) {
LOG(WARNING) << "SSL_read took " << elapsed_time << " seconds and returned " << size << ' '
<< SSL_get_error(ssl_handle_.get(), size);
}
if (size <= 0) {
return process_ssl_error(size);
}
return size;
}
class SslReadByteFlow final : public ByteFlowBase {
public:
explicit SslReadByteFlow(SslStreamImpl *stream) : stream_(stream) {
}
bool loop() final {
auto to_read = output_.prepare_append();
auto r_size = stream_->read(to_read);
if (r_size.is_error()) {
finish(r_size.move_as_error());
return false;
}
auto size = r_size.move_as_ok();
if (size == 0) {
return false;
}
output_.confirm_append(size);
return true;
}
size_t read(MutableSlice data) {
return input_->advance(min(data.size(), input_->size()), data);
}
private:
SslStreamImpl *stream_;
};
class SslWriteByteFlow final : public ByteFlowBase {
public:
explicit SslWriteByteFlow(SslStreamImpl *stream) : stream_(stream) {
}
bool loop() final {
auto to_write = input_->prepare_read();
auto r_size = stream_->write(to_write);
if (r_size.is_error()) {
finish(r_size.move_as_error());
return false;
}
auto size = r_size.move_as_ok();
if (size == 0) {
return false;
}
input_->confirm_read(size);
return true;
}
size_t write(Slice data) {
output_.append(data);
return data.size();
}
private:
SslStreamImpl *stream_;
};
SslReadByteFlow read_flow_{this};
SslWriteByteFlow write_flow_{this};
Result<size_t> process_ssl_error(int ret) {
auto os_error = OS_ERROR("SSL_ERROR_SYSCALL");
int error = SSL_get_error(ssl_handle_.get(), ret);
switch (error) {
case SSL_ERROR_NONE:
LOG(ERROR) << "SSL_get_error returned no error";
return 0;
case SSL_ERROR_ZERO_RETURN:
LOG(DEBUG) << "SSL_ZERO_RETURN";
return 0;
case SSL_ERROR_WANT_READ:
LOG(DEBUG) << "SSL_WANT_READ";
return 0;
case SSL_ERROR_WANT_WRITE:
LOG(DEBUG) << "SSL_WANT_WRITE";
return 0;
case SSL_ERROR_WANT_CONNECT:
case SSL_ERROR_WANT_ACCEPT:
case SSL_ERROR_WANT_X509_LOOKUP:
LOG(DEBUG) << "SSL: CONNECT ACCEPT LOOKUP";
return 0;
case SSL_ERROR_SYSCALL:
if (ERR_peek_error() == 0) {
if (os_error.code() != 0) {
LOG(DEBUG) << "SSL_ERROR_SYSCALL";
return std::move(os_error);
} else {
LOG(DEBUG) << "SSL_SYSCALL";
return 0;
}
}
/* fallthrough */
default:
LOG(DEBUG) << "SSL_ERROR Default";
return create_openssl_error(1, "SSL error ");
}
}
};
namespace {
int strm_read(BIO *b, char *buf, int len) {
auto *stream = static_cast<SslStreamImpl *>(BIO_get_data(b));
CHECK(stream != nullptr);
BIO_clear_retry_flags(b);
CHECK(buf != nullptr);
auto res = narrow_cast<int>(stream->flow_read(MutableSlice(buf, len)));
if (res == 0) {
BIO_set_retry_read(b);
return -1;
}
return res;
}
int strm_write(BIO *b, const char *buf, int len) {
auto *stream = static_cast<SslStreamImpl *>(BIO_get_data(b));
CHECK(stream != nullptr);
BIO_clear_retry_flags(b);
CHECK(buf != nullptr);
return narrow_cast<int>(stream->flow_write(Slice(buf, len)));
}
} // namespace
} // namespace detail
SslStream::SslStream() = default;
SslStream::SslStream(SslStream &&) noexcept = default;
SslStream &SslStream::operator=(SslStream &&) noexcept = default;
SslStream::~SslStream() = default;
Result<SslStream> SslStream::create(CSlice host, SslCtx ssl_ctx, bool use_ip_address_as_host) {
auto impl = make_unique<detail::SslStreamImpl>();
TRY_STATUS(impl->init(host, ssl_ctx, use_ip_address_as_host));
return SslStream(std::move(impl));
}
SslStream::SslStream(unique_ptr<detail::SslStreamImpl> impl) : impl_(std::move(impl)) {
}
ByteFlowInterface &SslStream::read_byte_flow() {
return impl_->read_byte_flow();
}
ByteFlowInterface &SslStream::write_byte_flow() {
return impl_->write_byte_flow();
}
size_t SslStream::flow_read(MutableSlice slice) {
return impl_->flow_read(slice);
}
size_t SslStream::flow_write(Slice slice) {
return impl_->flow_write(slice);
}
} // namespace td
#else
namespace td {
namespace detail {
class SslStreamImpl {};
} // namespace detail
SslStream::SslStream() = default;
SslStream::SslStream(SslStream &&) noexcept = default;
SslStream &SslStream::operator=(SslStream &&) noexcept = default;
SslStream::~SslStream() = default;
Result<SslStream> SslStream::create(CSlice host, SslCtx ssl_ctx, bool check_ip_address_as_host) {
return Status::Error("Not supported in Emscripten");
}
SslStream::SslStream(unique_ptr<detail::SslStreamImpl> impl) : impl_(std::move(impl)) {
}
ByteFlowInterface &SslStream::read_byte_flow() {
UNREACHABLE();
}
ByteFlowInterface &SslStream::write_byte_flow() {
UNREACHABLE();
}
size_t SslStream::flow_read(MutableSlice slice) {
UNREACHABLE();
}
size_t SslStream::flow_write(Slice slice) {
UNREACHABLE();
}
} // namespace td
#endif

View File

@@ -0,0 +1,46 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once
#include "td/net/SslCtx.h"
#include "td/utils/ByteFlow.h"
#include "td/utils/Slice.h"
#include "td/utils/Status.h"
namespace td {
namespace detail {
class SslStreamImpl;
} // namespace detail
class SslStream {
public:
SslStream();
SslStream(SslStream &&) noexcept;
SslStream &operator=(SslStream &&) noexcept;
~SslStream();
static Result<SslStream> create(CSlice host, SslCtx ssl_ctx, bool use_ip_address_as_host = false);
ByteFlowInterface &read_byte_flow();
ByteFlowInterface &write_byte_flow();
size_t flow_read(MutableSlice slice);
size_t flow_write(Slice slice);
explicit operator bool() const noexcept {
return static_cast<bool>(impl_);
}
private:
unique_ptr<detail::SslStreamImpl> impl_;
explicit SslStream(unique_ptr<detail::SslStreamImpl> impl);
};
} // namespace td

View File

@@ -0,0 +1,64 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#include "td/net/TcpListener.h"
#include "td/utils/logging.h"
#include "td/utils/port/detail/PollableFd.h"
namespace td {
TcpListener::TcpListener(int port, ActorShared<Callback> callback, Slice server_address)
: port_(port), callback_(std::move(callback)), server_address_(server_address.str()) {
}
void TcpListener::hangup() {
stop();
}
void TcpListener::start_up() {
auto r_socket = ServerSocketFd::open(port_, server_address_);
if (r_socket.is_error()) {
LOG(ERROR) << "Can't open server socket: " << r_socket.error();
set_timeout_in(5);
return;
}
server_fd_ = r_socket.move_as_ok();
Scheduler::subscribe(server_fd_.get_poll_info().extract_pollable_fd(this));
}
void TcpListener::tear_down() {
if (!server_fd_.empty()) {
Scheduler::unsubscribe_before_close(server_fd_.get_poll_info().get_pollable_fd_ref());
server_fd_.close();
}
}
void TcpListener::loop() {
if (server_fd_.empty()) {
start_up();
if (server_fd_.empty()) {
return;
}
}
sync_with_poll(server_fd_);
while (can_read_local(server_fd_)) {
auto r_socket_fd = server_fd_.accept();
if (r_socket_fd.is_error()) {
if (r_socket_fd.error().code() != -1) {
LOG(ERROR) << r_socket_fd.error();
}
continue;
}
send_closure(callback_, &Callback::accept, r_socket_fd.move_as_ok());
}
if (can_close_local(server_fd_)) {
stop();
}
}
} // namespace td

View File

@@ -0,0 +1,38 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once
#include "td/actor/actor.h"
#include "td/utils/common.h"
#include "td/utils/port/ServerSocketFd.h"
#include "td/utils/port/SocketFd.h"
#include "td/utils/Slice.h"
namespace td {
class TcpListener final : public Actor {
public:
class Callback : public Actor {
public:
virtual void accept(SocketFd fd) = 0;
};
TcpListener(int port, ActorShared<Callback> callback, Slice server_address = Slice("0.0.0.0"));
void hangup() final;
private:
int port_;
ServerSocketFd server_fd_;
ActorShared<Callback> callback_;
const string server_address_;
void start_up() final;
void tear_down() final;
void loop() final;
};
} // namespace td

View File

@@ -0,0 +1,84 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#include "td/net/TransparentProxy.h"
#include "td/utils/logging.h"
#include "td/utils/port/detail/PollableFd.h"
namespace td {
int VERBOSITY_NAME(proxy) = VERBOSITY_NAME(DEBUG);
TransparentProxy::TransparentProxy(SocketFd socket_fd, IPAddress ip_address, string username, string password,
unique_ptr<Callback> callback, ActorShared<> parent)
: fd_(std::move(socket_fd))
, ip_address_(std::move(ip_address))
, username_(std::move(username))
, password_(std::move(password))
, callback_(std::move(callback))
, parent_(std::move(parent)) {
}
void TransparentProxy::on_error(Status status) {
CHECK(status.is_error());
VLOG(proxy) << "Receive " << status;
if (callback_) {
callback_->set_result(std::move(status));
callback_.reset();
}
stop();
}
void TransparentProxy::tear_down() {
VLOG(proxy) << "Finish to connect to proxy";
Scheduler::unsubscribe(fd_.get_poll_info().get_pollable_fd_ref());
if (callback_) {
if (!fd_.input_buffer().empty()) {
LOG(ERROR) << "Have " << fd_.input_buffer().size() << " unread bytes";
callback_->set_result(Status::Error("Proxy has sent too many data"));
} else {
callback_->set_result(std::move(fd_));
}
callback_.reset();
}
}
void TransparentProxy::hangup() {
on_error(Status::Error("Canceled"));
}
void TransparentProxy::start_up() {
VLOG(proxy) << "Begin to connect to proxy";
Scheduler::subscribe(fd_.get_poll_info().extract_pollable_fd(this));
set_timeout_in(10);
sync_with_poll(fd_);
if (can_write_local(fd_)) {
loop();
}
}
void TransparentProxy::loop() {
sync_with_poll(fd_);
auto status = [&] {
TRY_STATUS(fd_.flush_read());
TRY_STATUS(loop_impl());
TRY_STATUS(fd_.flush_write());
if (can_close_local(fd_)) {
return Status::Error("Connection closed");
}
return Status::OK();
}();
if (status.is_error()) {
on_error(std::move(status));
}
}
void TransparentProxy::timeout_expired() {
on_error(Status::Error("Connection timeout expired"));
}
} // namespace td

View File

@@ -0,0 +1,57 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once
#include "td/actor/actor.h"
#include "td/utils/BufferedFd.h"
#include "td/utils/common.h"
#include "td/utils/logging.h"
#include "td/utils/port/IPAddress.h"
#include "td/utils/port/SocketFd.h"
#include "td/utils/Status.h"
namespace td {
extern int VERBOSITY_NAME(proxy);
class TransparentProxy : public Actor {
public:
class Callback {
public:
Callback() = default;
Callback(const Callback &) = delete;
Callback &operator=(const Callback &) = delete;
virtual ~Callback() = default;
virtual void set_result(Result<BufferedFd<SocketFd>> r_buffered_socket_fd) = 0;
virtual void on_connected() = 0;
};
TransparentProxy(SocketFd socket_fd, IPAddress ip_address, string username, string password,
unique_ptr<Callback> callback, ActorShared<> parent);
protected:
BufferedFd<SocketFd> fd_;
IPAddress ip_address_;
string username_;
string password_;
unique_ptr<Callback> callback_;
ActorShared<> parent_;
void on_error(Status status);
void tear_down() override;
void start_up() override;
void hangup() override;
void loop() override;
void timeout_expired() override;
virtual Status loop_impl() = 0;
};
} // namespace td

158
td/tdnet/td/net/Wget.cpp Normal file
View File

@@ -0,0 +1,158 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#include "td/net/Wget.h"
#include "td/net/HttpHeaderCreator.h"
#include "td/net/HttpOutboundConnection.h"
#include "td/net/SslStream.h"
#include "td/utils/buffer.h"
#include "td/utils/BufferedFd.h"
#include "td/utils/HttpUrl.h"
#include "td/utils/logging.h"
#include "td/utils/misc.h"
#include "td/utils/port/IPAddress.h"
#include "td/utils/port/SocketFd.h"
#include "td/utils/Slice.h"
#include "td/utils/SliceBuilder.h"
#include <limits>
namespace td {
Wget::Wget(Promise<unique_ptr<HttpQuery>> promise, string url, std::vector<std::pair<string, string>> headers,
int32 timeout_in, int32 ttl, bool prefer_ipv6, SslCtx::VerifyPeer verify_peer, string content,
string content_type)
: promise_(std::move(promise))
, input_url_(std::move(url))
, headers_(std::move(headers))
, timeout_in_(timeout_in)
, ttl_(ttl)
, prefer_ipv6_(prefer_ipv6)
, verify_peer_(verify_peer)
, content_(std::move(content))
, content_type_(std::move(content_type)) {
}
Status Wget::try_init() {
TRY_RESULT(url, parse_url(input_url_));
TRY_RESULT_ASSIGN(url.host_, idn_to_ascii(url.host_));
HttpHeaderCreator hc;
if (content_.empty()) {
hc.init_get(url.query_);
} else {
hc.init_post(url.query_);
hc.set_content_size(content_.size());
if (!content_type_.empty()) {
hc.set_content_type(content_type_);
}
}
bool was_host = false;
bool was_accept_encoding = false;
for (auto &header : headers_) {
auto header_lower = to_lower(header.first);
if (header_lower == "host") {
was_host = true;
}
if (header_lower == "accept-encoding") {
was_accept_encoding = true;
}
hc.add_header(header.first, header.second);
}
if (!was_host) {
hc.add_header("Host", url.host_);
}
if (!was_accept_encoding) {
hc.add_header("Accept-Encoding", "gzip, deflate");
}
TRY_RESULT(header, hc.finish(content_));
IPAddress addr;
TRY_STATUS(addr.init_host_port(url.host_, url.port_, prefer_ipv6_));
TRY_RESULT(fd, SocketFd::open(addr));
if (fd.empty()) {
return Status::Error("Sockets are not supported");
}
if (url.protocol_ == HttpUrl::Protocol::Http) {
connection_ = create_actor<HttpOutboundConnection>("Connect", BufferedFd<SocketFd>(std::move(fd)), SslStream{},
std::numeric_limits<std::size_t>::max(), 0, 0,
ActorOwn<HttpOutboundConnection::Callback>(actor_id(this)));
} else {
TRY_RESULT(ssl_ctx, SslCtx::create(CSlice() /* certificate */, verify_peer_));
TRY_RESULT(ssl_stream, SslStream::create(url.host_, std::move(ssl_ctx)));
connection_ = create_actor<HttpOutboundConnection>(
"Connect", BufferedFd<SocketFd>(std::move(fd)), std::move(ssl_stream), std::numeric_limits<std::size_t>::max(),
0, 0, ActorOwn<HttpOutboundConnection::Callback>(actor_id(this)));
}
send_closure(connection_, &HttpOutboundConnection::write_next, BufferSlice(header));
send_closure(connection_, &HttpOutboundConnection::write_ok);
return Status::OK();
}
void Wget::loop() {
if (connection_.empty()) {
auto status = try_init();
if (status.is_error()) {
return on_error(std::move(status));
}
}
}
void Wget::handle(unique_ptr<HttpQuery> result) {
on_ok(std::move(result));
}
void Wget::on_connection_error(Status error) {
on_error(std::move(error));
}
void Wget::on_ok(unique_ptr<HttpQuery> http_query_ptr) {
CHECK(promise_);
CHECK(http_query_ptr);
if ((http_query_ptr->code_ == 301 || http_query_ptr->code_ == 302 || http_query_ptr->code_ == 307 ||
http_query_ptr->code_ == 308) &&
ttl_ > 0) {
LOG(DEBUG) << *http_query_ptr;
input_url_ = http_query_ptr->get_header("location").str();
LOG(DEBUG) << input_url_;
ttl_--;
connection_.reset();
yield();
} else if (http_query_ptr->code_ >= 200 && http_query_ptr->code_ < 300) {
promise_.set_value(std::move(http_query_ptr));
stop();
} else {
on_error(Status::Error(PSLICE() << "HTTP error: " << http_query_ptr->code_));
}
}
void Wget::on_error(Status error) {
CHECK(error.is_error());
CHECK(promise_);
promise_.set_error(std::move(error));
stop();
}
void Wget::start_up() {
set_timeout_in(timeout_in_);
loop();
}
void Wget::timeout_expired() {
on_error(Status::Error("Response timeout expired"));
}
void Wget::tear_down() {
if (promise_) {
on_error(Status::Error("Canceled"));
}
}
} // namespace td

53
td/tdnet/td/net/Wget.h Normal file
View File

@@ -0,0 +1,53 @@
//
// Copyright Aliaksei Levin (levlam@telegram.org), Arseny Smirnov (arseny30@gmail.com) 2014-2024
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#pragma once
#include "td/net/HttpOutboundConnection.h"
#include "td/net/HttpQuery.h"
#include "td/net/SslCtx.h"
#include "td/actor/actor.h"
#include "td/utils/common.h"
#include "td/utils/Promise.h"
#include "td/utils/Status.h"
#include <utility>
namespace td {
class Wget final : public HttpOutboundConnection::Callback {
public:
Wget(Promise<unique_ptr<HttpQuery>> promise, string url, std::vector<std::pair<string, string>> headers = {},
int32 timeout_in = 10, int32 ttl = 3, bool prefer_ipv6 = false,
SslCtx::VerifyPeer verify_peer = SslCtx::VerifyPeer::On, string content = {}, string content_type = {});
private:
Status try_init();
void loop() final;
void handle(unique_ptr<HttpQuery> result) final;
void on_connection_error(Status error) final;
void on_ok(unique_ptr<HttpQuery> http_query_ptr);
void on_error(Status error);
void tear_down() final;
void start_up() final;
void timeout_expired() final;
Promise<unique_ptr<HttpQuery>> promise_;
ActorOwn<HttpOutboundConnection> connection_;
string input_url_;
std::vector<std::pair<string, string>> headers_;
int32 timeout_in_;
int32 ttl_;
bool prefer_ipv6_ = false;
SslCtx::VerifyPeer verify_peer_;
string content_;
string content_type_;
};
} // namespace td