Fix socks5_server_socket overflows

This commit is contained in:
klzgrad
2026-05-07 21:54:29 +08:00
parent b5bda24ebe
commit f74d42c2a7
2 changed files with 56 additions and 58 deletions

View File

@@ -2,19 +2,17 @@
// Copyright 2018 klzgrad <kizdiv@gmail.com>. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifdef UNSAFE_BUFFERS_BUILD
// TODO(crbug.com/40284755): Remove this and spanify to fix the errors.
#pragma allow_unsafe_buffers
#endif
#include "net/tools/naive/socks5_server_socket.h"
#include <cstring>
#include <cstdint>
#include <utility>
#include "base/containers/extend.h"
#include "base/functional/bind.h"
#include "base/functional/callback_helpers.h"
#include "base/logging.h"
#include "base/numerics/byte_conversions.h"
#include "base/strings/string_view_util.h"
#include "base/sys_byteorder.h"
#include "net/base/ip_address.h"
#include "net/base/net_errors.h"
@@ -33,16 +31,16 @@ enum SocksCommandType {
static constexpr unsigned int kGreetReadHeaderSize = 2;
static constexpr unsigned int kAuthReadHeaderSize = 2;
static constexpr unsigned int kReadHeaderSize = 5;
static constexpr char kSOCKS5Version = '\x05';
static constexpr char kSOCKS5Reserved = '\x00';
static constexpr char kAuthMethodNone = '\x00';
static constexpr char kAuthMethodUserPass = '\x02';
static constexpr char kAuthMethodNoAcceptable = '\xff';
static constexpr char kSubnegotiationVersion = '\x01';
static constexpr char kAuthStatusSuccess = '\x00';
static constexpr char kAuthStatusFailure = '\xff';
static constexpr char kReplySuccess = '\x00';
static constexpr char kReplyCommandNotSupported = '\x07';
static constexpr uint8_t kSOCKS5Version = '\x05';
static constexpr uint8_t kSOCKS5Reserved = '\x00';
static constexpr uint8_t kAuthMethodNone = '\x00';
static constexpr uint8_t kAuthMethodUserPass = '\x02';
static constexpr uint8_t kAuthMethodNoAcceptable = '\xff';
static constexpr uint8_t kSubnegotiationVersion = '\x01';
static constexpr uint8_t kAuthStatusSuccess = '\x00';
static constexpr uint8_t kAuthStatusFailure = '\xff';
static constexpr uint8_t kReplySuccess = '\x00';
static constexpr uint8_t kReplyCommandNotSupported = '\x07';
static_assert(sizeof(struct in_addr) == 4, "incorrect system size of IPv4");
static_assert(sizeof(struct in6_addr) == 16, "incorrect system size of IPv6");
@@ -318,7 +316,7 @@ int Socks5ServerSocket::DoGreetReadComplete(int result) {
return ERR_SOCKS_CONNECTION_FAILED;
}
buffer_.append(handshake_buf_->data(), result);
base::Extend(buffer_, handshake_buf_->first(result));
// When the first few bytes are read, check how many more are required
// and accordingly increase them
@@ -328,7 +326,7 @@ int Socks5ServerSocket::DoGreetReadComplete(int result) {
"version", buffer_[0]);
return ERR_SOCKS_CONNECTION_FAILED;
}
int nmethods = buffer_[1];
uint8_t nmethods = buffer_[1];
if (nmethods == 0) {
net_log_.AddEvent(NetLogEventType::SOCKS_NO_REQUESTED_AUTH);
return ERR_SOCKS_CONNECTION_FAILED;
@@ -340,17 +338,18 @@ int Socks5ServerSocket::DoGreetReadComplete(int result) {
}
if (buffer_.size() == read_header_size_) {
int nmethods = buffer_[1];
char expected_method = kAuthMethodNone;
uint8_t nmethods = buffer_[1];
uint8_t expected_method = kAuthMethodNone;
if (!user_.empty() || !pass_.empty()) {
expected_method = kAuthMethodUserPass;
}
void* match =
std::memchr(&buffer_[kGreetReadHeaderSize], expected_method, nmethods);
if (match) {
auth_method_ = expected_method;
} else {
auth_method_ = kAuthMethodNoAcceptable;
auth_method_ = kAuthMethodNoAcceptable;
for (uint8_t method :
base::span(buffer_).subspan(kGreetReadHeaderSize, nmethods)) {
if (method == expected_method) {
auth_method_ = expected_method;
}
}
buffer_.clear();
next_state_ = STATE_GREET_WRITE;
@@ -363,8 +362,7 @@ int Socks5ServerSocket::DoGreetReadComplete(int result) {
int Socks5ServerSocket::DoGreetWrite() {
if (buffer_.empty()) {
const char write_data[] = {kSOCKS5Version, auth_method_};
buffer_ = std::string(write_data, std::size(write_data));
buffer_ = {kSOCKS5Version, auth_method_};
bytes_sent_ = 0;
}
@@ -372,8 +370,8 @@ int Socks5ServerSocket::DoGreetWrite() {
int handshake_buf_len = buffer_.size() - bytes_sent_;
DCHECK_LT(0, handshake_buf_len);
handshake_buf_ = base::MakeRefCounted<IOBufferWithSize>(handshake_buf_len);
std::memcpy(handshake_buf_->data(), &buffer_.data()[bytes_sent_],
handshake_buf_len);
handshake_buf_->span().copy_from(base::span(buffer_).subspan(
bytes_sent_, static_cast<size_t>(handshake_buf_len)));
return transport_->Write(handshake_buf_.get(), handshake_buf_len,
io_callback_, traffic_annotation_);
}
@@ -423,7 +421,7 @@ int Socks5ServerSocket::DoAuthReadComplete(int result) {
return ERR_SOCKS_CONNECTION_FAILED;
}
buffer_.append(handshake_buf_->data(), result);
base::Extend(buffer_, handshake_buf_->first(result));
// When the first few bytes are read, check how many more are required
// and accordingly increase them
@@ -433,15 +431,15 @@ int Socks5ServerSocket::DoAuthReadComplete(int result) {
"version", buffer_[0]);
return ERR_SOCKS_CONNECTION_FAILED;
}
int username_len = buffer_[1];
size_t username_len = buffer_[1];
read_header_size_ += username_len + 1;
next_state_ = STATE_AUTH_READ;
return OK;
}
if (buffer_.size() == read_header_size_) {
int username_len = buffer_[1];
int password_len = buffer_[kAuthReadHeaderSize + username_len];
size_t username_len = buffer_[1];
size_t password_len = buffer_[kAuthReadHeaderSize + username_len];
size_t password_offset = kAuthReadHeaderSize + username_len + 1;
if (buffer_.size() == password_offset && password_len != 0) {
read_header_size_ += password_len;
@@ -449,8 +447,10 @@ int Socks5ServerSocket::DoAuthReadComplete(int result) {
return OK;
}
if (buffer_.compare(kAuthReadHeaderSize, username_len, user_) == 0 &&
buffer_.compare(password_offset, password_len, pass_) == 0) {
if (base::span(buffer_).subspan(kAuthReadHeaderSize, username_len) ==
base::as_byte_span(user_) &&
base::span(buffer_).subspan(password_offset, password_len) ==
base::as_byte_span(pass_)) {
auth_status_ = kAuthStatusSuccess;
} else {
auth_status_ = kAuthStatusFailure;
@@ -466,8 +466,7 @@ int Socks5ServerSocket::DoAuthReadComplete(int result) {
int Socks5ServerSocket::DoAuthWrite() {
if (buffer_.empty()) {
const char write_data[] = {kSubnegotiationVersion, auth_status_};
buffer_ = std::string(write_data, std::size(write_data));
buffer_ = {kSubnegotiationVersion, auth_status_};
bytes_sent_ = 0;
}
@@ -475,8 +474,8 @@ int Socks5ServerSocket::DoAuthWrite() {
int handshake_buf_len = buffer_.size() - bytes_sent_;
DCHECK_LT(0, handshake_buf_len);
handshake_buf_ = base::MakeRefCounted<IOBufferWithSize>(handshake_buf_len);
std::memcpy(handshake_buf_->data(), &buffer_.data()[bytes_sent_],
handshake_buf_len);
handshake_buf_->span().copy_from(base::span(buffer_).subspan(
bytes_sent_, static_cast<size_t>(handshake_buf_len)));
return transport_->Write(handshake_buf_.get(), handshake_buf_len,
io_callback_, traffic_annotation_);
}
@@ -526,7 +525,7 @@ int Socks5ServerSocket::DoHandshakeReadComplete(int result) {
return ERR_SOCKS_CONNECTION_FAILED;
}
buffer_.append(handshake_buf_->data(), result);
base::Extend(buffer_, handshake_buf_->first(result));
// When the first few bytes are read, check how many more are required
// and accordingly increase them
@@ -556,7 +555,7 @@ int Socks5ServerSocket::DoHandshakeReadComplete(int result) {
// read, we substract 1 byte from the additional request size.
address_type_ = static_cast<SocksEndPointAddressType>(buffer_[3]);
if (address_type_ == kEndPointDomain) {
address_size_ = static_cast<uint8_t>(buffer_[4]);
address_size_ = buffer_[4];
if (address_size_ == 0) {
net_log_.AddEvent(NetLogEventType::SOCKS_ZERO_LENGTH_DOMAIN);
return ERR_SOCKS_CONNECTION_FAILED;
@@ -583,18 +582,17 @@ int Socks5ServerSocket::DoHandshakeReadComplete(int result) {
// When the final bytes are read, setup handshake.
if (buffer_.size() == read_header_size_) {
size_t port_start = read_header_size_ - sizeof(uint16_t);
uint16_t port_net;
std::memcpy(&port_net, &buffer_[port_start], sizeof(uint16_t));
uint16_t port_host = base::NetToHost16(port_net);
uint16_t port_host = base::U16FromBigEndian(
base::span(buffer_).subspan(port_start).first<2>());
size_t address_start = port_start - address_size_;
base::span<const uint8_t> addr_span =
base::span(buffer_).subspan(address_start, address_size_);
if (address_type_ == kEndPointDomain) {
std::string domain(&buffer_[address_start], address_size_);
request_endpoint_ = HostPortPair(domain, port_host);
request_endpoint_ =
HostPortPair(base::as_string_view(addr_span), port_host);
} else {
IPAddress ip_addr(base::span<const uint8_t>{
reinterpret_cast<const uint8_t*>(&buffer_[address_start]),
static_cast<size_t>(address_size_)});
IPAddress ip_addr(addr_span);
IPEndPoint endpoint(ip_addr, port_host);
request_endpoint_ = HostPortPair::FromIPEndPoint(endpoint);
}
@@ -612,7 +610,7 @@ int Socks5ServerSocket::DoHandshakeWrite() {
next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE;
if (buffer_.empty()) {
const char write_data[] = {
buffer_ = {
// clang-format off
kSOCKS5Version,
reply_,
@@ -622,14 +620,14 @@ int Socks5ServerSocket::DoHandshakeWrite() {
0x00, 0x00, // BND.PORT
// clang-format on
};
buffer_ = std::string(write_data, std::size(write_data));
bytes_sent_ = 0;
}
int handshake_buf_len = buffer_.size() - bytes_sent_;
DCHECK_LT(0, handshake_buf_len);
handshake_buf_ = base::MakeRefCounted<IOBufferWithSize>(handshake_buf_len);
std::memcpy(handshake_buf_->data(), &buffer_[bytes_sent_], handshake_buf_len);
handshake_buf_->span().copy_from(base::span(buffer_).subspan(
bytes_sent_, static_cast<size_t>(handshake_buf_len)));
return transport_->Write(handshake_buf_.get(), handshake_buf_len,
io_callback_, traffic_annotation_);
}

View File

@@ -131,7 +131,7 @@ class Socks5ServerSocket : public StreamSocket {
// While writing, this buffer stores the complete write handshake data.
// While reading, it stores the handshake information received so far.
std::string buffer_;
std::vector<uint8_t> buffer_;
// This becomes true when the SOCKS handshake has completed and the
// overlying connection is free to communicate.
@@ -145,13 +145,13 @@ class Socks5ServerSocket : public StreamSocket {
bool was_ever_used_;
SocksEndPointAddressType address_type_;
int address_size_;
size_t address_size_;
std::string user_;
std::string pass_;
char auth_method_;
char auth_status_;
char reply_;
uint8_t auth_method_;
uint8_t auth_status_;
uint8_t reply_;
HostPortPair request_endpoint_;