Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Downloader threadsafe and more thread compliant. #886

Merged
merged 9 commits into from
Feb 8, 2023
139 changes: 124 additions & 15 deletions include/downloader.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <map>
#include <memory>
#include <stdexcept>
#include <mutex>

namespace kiwix
{
Expand All @@ -43,6 +44,14 @@ class AriaError : public std::runtime_error {
};


/**
* A representation of a current download.
*
* `Download` is not thread safe. User must care to not call method on a
* same download from different threads.
* However, it is safe to use different `Download`s from different threads.
*/

class Download {
public:
typedef enum { K_ACTIVE, K_WAITING, K_PAUSED, K_ERROR, K_COMPLETE, K_REMOVED, K_UNKNOWN } StatusResult;
Expand All @@ -53,19 +62,89 @@ class Download {
: mp_aria(p_aria),
m_status(K_UNKNOWN),
m_did(did) {};
void updateStatus(bool follow=false);

/**
* Update the status of the download.
*
* This call make an aria rpc call and is blocking.
* Some download (started with a metalink) are in fact several downloads.
* - A first one to download the metadlink.
* - A second one to download the real file.
*
* If `follow` is true, updateStatus tries to detect that and tracks
* the second download when the first one is finished.
* By passing false to `follow`, `Download` will only track the first download.
*
* `getFoo` methods are based on the last statusUpdate.
*
* @param follow: Do we have to follow following downloads.
*/
void updateStatus(bool follow);

/**
* Pause the download (and call updateStatus)
*/
void pauseDownload();

/**
* Resume the download (and call updateStatus)
*/
void resumeDownload();

/**
* Cancel the download.
*
* A canceled downlod cannot be resume and updateStatus does nothing.
* However, you can still get information based on the last known information.
*/
void cancelDownload();
StatusResult getStatus() { return m_status; }
std::string getDid() { return m_did; }
std::string getFollowedBy() { return m_followedBy; }
uint64_t getTotalLength() { return m_totalLength; }
uint64_t getCompletedLength() { return m_completedLength; }
uint64_t getDownloadSpeed() { return m_downloadSpeed; }
uint64_t getVerifiedLength() { return m_verifiedLength; }
std::string getPath() { return m_path; }
std::vector<std::string>& getUris() { return m_uris; }

/*
* Get the status of the download.
*/
StatusResult getStatus() const { return m_status; }

/*
* Get the id of the download.
*/
const std::string& getDid() const { return m_did; }

/*
* Get the id of the "second" download.
*
* Set only if the "first" download is a metalink and is complete.
*/
const std::string& getFollowedBy() const { return m_followedBy; }

/*
* Get the total length of the download.
*/
uint64_t getTotalLength() const { return m_totalLength; }

/*
* Get the completed length of the download.
*/
uint64_t getCompletedLength() const { return m_completedLength; }

/*
* Get the download speed of the download.
*/
uint64_t getDownloadSpeed() const { return m_downloadSpeed; }

/*
* Get the verified length of the download.
*/
uint64_t getVerifiedLength() const { return m_verifiedLength; }

/*
* Get the path (local file) of the download.
*/
const std::string& getPath() const { return m_path; }

/*
* Get the download uris of the download.
*/
const std::vector<std::string>& getUris() const { return m_uris; }

protected:
std::shared_ptr<Aria2> mp_aria;
Expand All @@ -83,6 +162,9 @@ class Download {
/**
* A tool to download things.
*
* A Downloader manages `Download` using aria2 in the background.
* `Downloader` is threadsafe.
* However, the returned `Download`s are NOT threadsafe.
*/
class Downloader
{
Expand All @@ -92,14 +174,41 @@ class Downloader

void close();

Download* startDownload(const std::string& uri, const std::vector<std::pair<std::string, std::string>>& options = {});
Download* getDownload(const std::string& did);
/**
* Start a new download.
*
* This method is thread safe and return a pointer to a newly created `Download`.
* User should call `update` on the returned `Download` to have an accurate status.
*
* @param uri: The uri of the thing to download.
* @param options: A series of pair <option_name, option_value> to pass to aria.
* @return: The newly created Download.
*/
std::shared_ptr<Download> startDownload(const std::string& uri, const std::vector<std::pair<std::string, std::string>>& options = {});

/**
* Get a download corrsponding to a download id (did)
* User should call `update` on the returned `Download` to have an accurate status.
*
* @param did: The download id to search for.
* @return: The Download corresponding to did.
* @throw: Throw std::out_of_range if did is not found.
*/
std::shared_ptr<Download> getDownload(const std::string& did);

/**
* Get the number of downloads currently managed.
*/
size_t getNbDownload() const;

size_t getNbDownload() { return m_knownDownloads.size(); }
std::vector<std::string> getDownloadIds();
/**
* Get the ids of the managed downloads.
*/
std::vector<std::string> getDownloadIds() const;

private:
std::map<std::string, std::unique_ptr<Download>> m_knownDownloads;
mutable std::mutex m_lock;
std::map<std::string, std::shared_ptr<Download>> m_knownDownloads;
std::shared_ptr<Aria2> mp_aria;
};
}
Expand Down
61 changes: 30 additions & 31 deletions src/aria2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,15 @@
#define LOG_ARIA_ERROR() \
{ \
std::cerr << "ERROR: aria2 RPC request failed. (" << res << ")." << std::endl; \
std::cerr << (m_curlErrorBuffer[0] ? m_curlErrorBuffer.get() : curl_easy_strerror(res)) << std::endl; \
std::cerr << (curlErrorBuffer[0] ? curlErrorBuffer : curl_easy_strerror(res)) << std::endl; \
}

namespace kiwix {

Aria2::Aria2():
mp_aria(nullptr),
m_port(42042),
m_secret(getNewRpcSecret()),
m_curlErrorBuffer(new char[CURL_ERROR_SIZE]),
mp_curl(nullptr)
m_secret(getNewRpcSecret())
{
m_downloadDir = getDataDirectory();
makeDirectory(m_downloadDir);
Expand Down Expand Up @@ -91,36 +89,32 @@ Aria2::Aria2():
launchCmd.append(cmd).append(" ");
}
mp_aria = Subprocess::run(callCmd);
mp_curl = curl_easy_init();

curl_easy_setopt(mp_curl, CURLOPT_URL, "http://localhost/rpc");
curl_easy_setopt(mp_curl, CURLOPT_PORT, m_port);
curl_easy_setopt(mp_curl, CURLOPT_POST, 1L);
curl_easy_setopt(mp_curl, CURLOPT_ERRORBUFFER, m_curlErrorBuffer.get());
CURL* p_curl = curl_easy_init();
char curlErrorBuffer[CURL_ERROR_SIZE];

curl_easy_setopt(p_curl, CURLOPT_URL, "http://localhost/rpc");
curl_easy_setopt(p_curl, CURLOPT_PORT, m_port);
curl_easy_setopt(p_curl, CURLOPT_POST, 1L);
curl_easy_setopt(p_curl, CURLOPT_ERRORBUFFER, curlErrorBuffer);

int watchdog = 50;
while(--watchdog) {
sleep(10);
m_curlErrorBuffer[0] = 0;
auto res = curl_easy_perform(mp_curl);
curlErrorBuffer[0] = 0;
auto res = curl_easy_perform(p_curl);
if (res == CURLE_OK) {
break;
} else if (watchdog == 1) {
LOG_ARIA_ERROR();
}
}
curl_easy_cleanup(p_curl);
if (!watchdog) {
curl_easy_cleanup(mp_curl);
throw std::runtime_error("Cannot connect to aria2c rpc. Aria2c launch cmd : " + launchCmd);
}
}

Aria2::~Aria2()
{
std::unique_lock<std::mutex> lock(m_lock);
curl_easy_cleanup(mp_curl);
}

void Aria2::close()
{
saveSession();
Expand All @@ -140,20 +134,25 @@ std::string Aria2::doRequest(const MethodCall& methodCall)
std::stringstream outStream;
CURLcode res;
long response_code;
{
std::unique_lock<std::mutex> lock(m_lock);
curl_easy_setopt(mp_curl, CURLOPT_POSTFIELDSIZE, requestContent.size());
curl_easy_setopt(mp_curl, CURLOPT_POSTFIELDS, requestContent.c_str());
curl_easy_setopt(mp_curl, CURLOPT_WRITEFUNCTION, &write_callback_to_iss);
curl_easy_setopt(mp_curl, CURLOPT_WRITEDATA, &outStream);
m_curlErrorBuffer[0] = 0;
res = curl_easy_perform(mp_curl);
if (res != CURLE_OK) {
LOG_ARIA_ERROR();
throw std::runtime_error("Cannot perform request");
}
curl_easy_getinfo(mp_curl, CURLINFO_RESPONSE_CODE, &response_code);
char curlErrorBuffer[CURL_ERROR_SIZE];
CURL* p_curl = curl_easy_init();
curl_easy_setopt(p_curl, CURLOPT_URL, "http://localhost/rpc");
curl_easy_setopt(p_curl, CURLOPT_PORT, m_port);
curl_easy_setopt(p_curl, CURLOPT_POST, 1L);
curl_easy_setopt(p_curl, CURLOPT_ERRORBUFFER, curlErrorBuffer);
curl_easy_setopt(p_curl, CURLOPT_POSTFIELDSIZE, requestContent.size());
curl_easy_setopt(p_curl, CURLOPT_POSTFIELDS, requestContent.c_str());
curl_easy_setopt(p_curl, CURLOPT_WRITEFUNCTION, &write_callback_to_iss);
curl_easy_setopt(p_curl, CURLOPT_WRITEDATA, &outStream);
curlErrorBuffer[0] = 0;
res = curl_easy_perform(p_curl);
if (res != CURLE_OK) {
LOG_ARIA_ERROR();
curl_easy_cleanup(p_curl);
throw std::runtime_error("Cannot perform request");
}
curl_easy_getinfo(p_curl, CURLINFO_RESPONSE_CODE, &response_code);
curl_easy_cleanup(p_curl);

auto responseContent = outStream.str();
if (response_code != 200) {
Expand Down
7 changes: 1 addition & 6 deletions src/aria2.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#include "xmlrpc.h"

#include <memory>
#include <mutex>
#include <curl/curl.h>

namespace kiwix {
Expand All @@ -24,15 +23,11 @@ class Aria2
int m_port;
std::string m_secret;
std::string m_downloadDir;
std::unique_ptr<char[]> m_curlErrorBuffer;
CURL* mp_curl;
std::mutex m_lock;

std::string doRequest(const MethodCall& methodCall);

public:
Aria2();
virtual ~Aria2();
virtual ~Aria2() = default;
void close();

std::string addUri(const std::vector<std::string>& uri, const std::vector<std::pair<std::string, std::string>>& options = {});
Expand Down
Loading