diff --git a/include/eepp/network/http.hpp b/include/eepp/network/http.hpp index ac834fd9f..69b58f96e 100644 --- a/include/eepp/network/http.hpp +++ b/include/eepp/network/http.hpp @@ -109,6 +109,27 @@ class EE_API Http : NonCopyable { /** Enables/Disables follow redirects */ void setFollowRedirect( bool follow ); + + /** Definition of the current progress callback + * @param http The http client + * @param request The http request + * @param totalBytes The total bytes of the document / files ( only available if Content-Length is returned, otherwise is 0 ) + * @param currentBytes Current received total bytes + * @return True if continue the request, false will cancel the current request. + */ + typedef std::function ProgressCallback; + + /** Sets a progress callback */ + void setProgressCallback( const ProgressCallback& progressCallback ); + + /** Get the progress callback */ + const ProgressCallback& getProgressCallback() const; + + /** Cancels the current request if being processed */ + void cancel(); + + /** @return True if the current request was cancelled */ + const bool& isCancelled() const; private: friend class Http; @@ -128,16 +149,18 @@ class EE_API Http : NonCopyable { typedef std::map FieldTable; // Member data - FieldTable mFields; ///< Fields of the header associated to their value - Method mMethod; ///< Method to use for the request - std::string mUri; ///< Target URI of the request - unsigned int mMajorVersion; ///< Major HTTP version - unsigned int mMinorVersion; ///< Minor HTTP version - std::string mBody; ///< Body of the request - bool mValidateCertificate; ///< Validates the SSL certificate in case of an HTTPS request - bool mValidateHostname; ///< Validates the hostname in case of an HTTPS request - bool mFollowRedirect; ///< Follows redirect response codes - unsigned int mRedirectionCount; ///< Number of redirections followed by the request + FieldTable mFields; ///< Fields of the header associated to their value + Method mMethod; ///< Method to use for the request + std::string mUri; ///< Target URI of the request + unsigned int mMajorVersion; ///< Major HTTP version + unsigned int mMinorVersion; ///< Minor HTTP version + std::string mBody; ///< Body of the request + bool mValidateCertificate; ///< Validates the SSL certificate in case of an HTTPS request + bool mValidateHostname; ///< Validates the hostname in case of an HTTPS request + bool mFollowRedirect; ///< Follows redirect response codes + mutable bool mCancel; ///< Cancel state of current request + ProgressCallback mProgressCallback; ///< Progress callback + mutable unsigned int mRedirectionCount; ///< Number of redirections followed by the request }; /** @brief Define a HTTP response */ diff --git a/src/eepp/network/http.cpp b/src/eepp/network/http.cpp index a2f6fea38..f22e1d638 100644 --- a/src/eepp/network/http.cpp +++ b/src/eepp/network/http.cpp @@ -17,6 +17,7 @@ Http::Request::Request(const std::string& uri, Method method, const std::string& mValidateCertificate( validateCertificate ), mValidateHostname( validateHostname ), mFollowRedirect( followRedirect ), + mCancel( false ), mRedirectionCount( 0 ) { setMethod(method); @@ -78,6 +79,22 @@ void Http::Request::setFollowRedirect(bool follow) { mFollowRedirect = follow; } +void Http::Request::setProgressCallback(const Http::Request::ProgressCallback& progressCallback) { + mProgressCallback = progressCallback; +} + +const Http::Request::ProgressCallback& Http::Request::getProgressCallback() const { + return mProgressCallback; +} + +void Http::Request::cancel() { + mCancel = true; +} + +const bool &Http::Request::isCancelled() const { + return mCancel; +} + std::string Http::Request::prepare() const { std::ostringstream out; @@ -365,7 +382,7 @@ Http::Response Http::sendRequest(const Http::Request& request, Time timeout) { if ( ( received.getStatus() == Response::MovedPermanently || received.getStatus() == Response::MovedTemporarily ) && request.getFollowRedirect() ) { - const_cast( request ).mRedirectionCount++; + request.mRedirectionCount++; // Only continue redirecting if less than 10 redirections were done if ( request.mRedirectionCount < 10 ) { @@ -378,7 +395,7 @@ Http::Response Http::sendRequest(const Http::Request& request, Time timeout) { } } - return received; + return std::move(received); } Http::Response Http::downloadRequest(const Http::Request& request, IOStream& writeTo, Time timeout) { @@ -391,27 +408,50 @@ Http::Response Http::downloadRequest(const Http::Request& request, IOStream& wri mConnection = Conn; } + // First make sure that the request is valid -- add missing mandatory fields Request toSend(prepareFields(request)); + + // Prepare the response Response received; + // Connect the socket to the host if (mConnection->connect(mHost, mPort, timeout) == Socket::Done) { + // Convert the request to string and send it through the connected socket std::string requestStr = toSend.prepare(); if (!requestStr.empty()) { + // Send it through the socket if (mConnection->send(requestStr.c_str(), requestStr.size()) == Socket::Done) { + // Wait for the server's response int isnheader = 0; - size_t len = 0; + std::size_t currentTotalBytes = 0; + std::size_t len = 0; char * eol; // end of line char * bol; // beginning of line std::size_t size = 0; - const size_t bufferSize = 1024; + const std::size_t bufferSize = 1024; char buffer[bufferSize+1]; std::string header; - while (mConnection->receive(buffer, bufferSize, size) == Socket::Done) { - if ( isnheader != 0 ) + while (!request.isCancelled() && mConnection->receive(buffer, bufferSize, size) == Socket::Done) { + if ( isnheader != 0 ) { + currentTotalBytes += size; writeTo.write( buffer, size ); + if ( request.getProgressCallback() ) { + std::size_t length = 0; + + if ( !received.getField("content-length").empty() ) { + String::fromString( length, received.getField("content-length") ); + } + + if ( !request.getProgressCallback()( *this, request, length, currentTotalBytes ) ) { + request.mCancel = true; + break; + } + } + } + if ( isnheader == 0 ) { // calculate combined length of unprocessed data and new data len += size; @@ -421,15 +461,19 @@ Http::Response Http::downloadRequest(const Http::Request& request, IOStream& wri // checks if the header break happened to be the first line of the buffer if ( !( strncmp( buffer, "\r\n", 2 ) ) ) { - if (len > 2) + if (len > 2) { + currentTotalBytes += (len-2); writeTo.write(buffer, (len-2)); + } continue; } if ( !( strncmp( buffer, "\n", 1 ) ) ) { - if ( len > 1 ) + if ( len > 1 ) { + currentTotalBytes += (len-1); writeTo.write(buffer, (len-1)); + } continue; } @@ -457,14 +501,43 @@ Http::Response Http::downloadRequest(const Http::Request& request, IOStream& wri len = len - ( bol - buffer ); // write remaining data to FILE stream - if ( len > 0 ) + if ( len > 0 ) { + currentTotalBytes += len; writeTo.write( bol, len ); + } header.append( buffer, ( bol - buffer ) ); // reset length of left over data to zero and continue processing // non-header information len = 0; + + if ( !header.empty() ) { + // Build the Response object from the received data + received.parse(header); + + // If a redirection is requested, and requests follows redirections, + // send a new request to the redirection location. + if ( ( received.getStatus() == Response::MovedPermanently || received.getStatus() == Response::MovedTemporarily ) && + request.getFollowRedirect() ) { + + request.mRedirectionCount++; + + // Only continue redirecting if less than 10 redirections were done + if ( request.mRedirectionCount < 10 ) { + std::string location( received.getField("location") ); + URI uri( location ); + Http http( uri.getHost(), uri.getPort(), uri.getScheme() == "https" ? true : false ); + Http::Request newRequest( request ); + newRequest.setUri( uri.getPathEtc() ); + + // Close the connection + mConnection->disconnect(); + + return http.downloadRequest( request, writeTo, timeout ); + } + } + } } } @@ -473,32 +546,6 @@ Http::Response Http::downloadRequest(const Http::Request& request, IOStream& wri } } } - - if ( !header.empty() ) { - received.parse(header); - - // If a redirection is requested, and requests follows redirections, - // send a new request to the redirection location. - if ( ( received.getStatus() == Response::MovedPermanently || received.getStatus() == Response::MovedTemporarily ) && - request.getFollowRedirect() ) { - - const_cast( request ).mRedirectionCount++; - - // Only continue redirecting if less than 10 redirections were done - if ( request.mRedirectionCount < 10 ) { - std::string location( received.getField("location") ); - URI uri( location ); - Http http( uri.getHost(), uri.getPort(), uri.getScheme() == "https" ? true : false ); - Http::Request newRequest( request ); - newRequest.setUri( uri.getPathEtc() ); - - // Close the connection - mConnection->disconnect(); - - return http.downloadRequest( request, writeTo, timeout ); - } - } - } } } @@ -506,7 +553,7 @@ Http::Response Http::downloadRequest(const Http::Request& request, IOStream& wri mConnection->disconnect(); } - return received; + return std::move(received); } Http::Response Http::downloadRequest(const Http::Request & request, std::string writePath, Time timeout) { @@ -595,7 +642,7 @@ void Http::removeOldThreads() { } } -Http::Request Http::prepareFields(const Http::Request & request) { +Http::Request Http::prepareFields(const Http::Request& request) { Request toSend(request); if (!toSend.hasField("User-Agent")) { @@ -620,7 +667,7 @@ Http::Request Http::prepareFields(const Http::Request & request) { toSend.setField("Connection", "close"); } - return toSend; + return std::move(toSend); } void Http::sendAsyncRequest( AsyncResponseCallback cb, const Http::Request& request, Time timeout ) { diff --git a/src/examples/http_request/http_request.cpp b/src/examples/http_request/http_request.cpp index f821f9647..c20488ec7 100644 --- a/src/examples/http_request/http_request.cpp +++ b/src/examples/http_request/http_request.cpp @@ -1,12 +1,23 @@ #include #include +void printResponseHeaders( Http::Response& response ) { + Http::Response::FieldTable headers = response.getHeaders(); + + std::cout << "\r\nHeaders: " << std::endl; + + for ( auto&& head : headers ) { + std::cout << "\t" << head.first << ": " << head.second << std::endl; + } +} + EE_MAIN_FUNC int main (int argc, char * argv []) { args::ArgumentParser parser("HTTP request program example"); args::HelpFlag help(parser, "help", "Display this help menu", {'h', "help"}); args::ValueFlag output(parser, "file", "Write to file instead of stdout", {'o', "output"} ); + args::Flag head(parser, "head", "Show document info", {'I',"head"} ); + args::Flag progress(parser, "progress", "Show current progress of a download", {'p',"progress"} ); args::Positional url(parser, "url", "The url to request"); - args::Flag verbose(parser, "verbose", "Prints the request response headers", {'v',"verbose"} ); try { parser.ParseCLI(argc, argv); @@ -73,14 +84,8 @@ EE_MAIN_FUNC int main (int argc, char * argv []) { Http::Response::Status status = response.getStatus(); if ( status == Http::Response::Ok ) { - if ( verbose ) { - Http::Response::FieldTable headers = response.getHeaders(); - - std::cout << "Headers: " << std::endl; - - for ( auto&& head : headers ) { - std::cout << "\t" << head.first << ": " << head.second << std::endl; - } + if ( head ) { + printResponseHeaders(response); std::cout << std::endl << "Body: " << std::endl; } @@ -90,12 +95,23 @@ EE_MAIN_FUNC int main (int argc, char * argv []) { std::cout << "Error " << status << std::endl; } } else { - http.downloadRequest(request, output.Get(), Seconds(5)); + if ( progress ) { + request.setProgressCallback( []( const Http& http, const Http::Request& request, size_t totalBytes, size_t currentBytes ) { + std::cout << "\rDownloaded " << FileSystem::sizeToString( currentBytes ).c_str() << " of " << FileSystem::sizeToString( totalBytes ).c_str() << " "; + std::cout << std::flush; + return true; + }); + } + + Http::Response response = http.downloadRequest(request, output.Get(), Seconds(5)); + + if ( head ) + printResponseHeaders(response); } } } - if ( verbose ) + if ( head ) MemoryManager::showResults(); return EXIT_SUCCESS;