天天看點

基于 POCO 架構的 TCP 連接配接分流程式介紹代碼測試

介紹

下面的程式實作了對 TCP 連接配接的分流,即将一個 TCP 連接配接的流量分布到多個 TCP 連接配接上進行傳輸。

本程式的主要作用是在特定網絡環境下提升通過 TCP 連接配接的 OpenVPN 服務的速率,使之充分利用帶寬。

程式的主要複雜之處在于單生産者—多消費者、多生産者—單消費者兩種同步方式的實作。

代碼

// multisocks, copyright (c) 2013 coolypf

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <string>
#include <map>
#include <vector>
#include <Poco/Thread.h>
#include <Poco/Event.h>
#include <Poco/ErrorHandler.h>
#include <Poco/FIFOBuffer.h>
#include <Poco/Net/StreamSocket.h>
#include <Poco/Net/ServerSocket.h>

using namespace std;
using namespace Poco;
using namespace Poco::Net;

class error_handler : public ErrorHandler
{
public:
    virtual void exception(const Exception& exc)
    {
        printf("[!] Thread %u: Caught exception: %s\n", (unsigned int)Thread::currentTid(), exc.displayText().c_str());
        fflush(stdout);
        exit(1);
    }

    virtual void exception(const std::exception& exc)
    {
        printf("[!] Thread %u: Caught exception: %s\n", (unsigned int)Thread::currentTid(), exc.what());
        fflush(stdout);
        exit(2);
    }

    virtual void exception()
    {
        printf("[!] Thread %u: Caught unknown exception\n", (unsigned int)Thread::currentTid());
        fflush(stdout);
        exit(3);
    }
};

class
{
    map<string, string> config;
    string empty;

    int read_file(const char *filename, vector<char> &out)
    {
        FILE *fp = fopen(filename, "rb");
        if (!fp)
        {
            printf("[!] Fail to load: %s\n", filename);
            return 1;
        }
        char buf[4096];
        size_t sz;
        while ((sz = fread(buf, 1, 4096, fp)))
            out.insert(out.end(), buf, buf + sz);
        fclose(fp);
        for (size_t i = 0; i < out.size(); ++i)
            if (!out[i])
                out[i] = ' ';
        return 0;
    }

    void split_lines(const vector<char> &in, vector<string> &out)
    {
        size_t pos = 0;
        for (size_t i = 0; i < in.size(); ++i)
        {
            if (in[i] != '\r' && in[i] != '\n')
                continue;
            out.push_back(string(in.begin() + pos, in.begin() + i));
            if (in[i] == '\r' && i + 1 < in.size() && in[i + 1] == '\n')
                i++;
            pos = i + 1;
        }
        if (pos < in.size())
            out.push_back(string(in.begin() + pos, in.end()));
    }

    void trim_string(string &str)
    {
        const char *spaces = " \r\n\t\xb\xc";
        int first = 0, last = (int)str.size() - 1;
        while (first <= last && strchr(spaces, str[first]))
            ++first;
        while (last >= first && strchr(spaces, str[last]))
            --last;
        str.resize(last + 1);
        str.erase(0, first);
    }

    void split_string(const string &in, vector<string> &out, const char *delims)
    {
        size_t pos = 0;
        for (size_t i = 0; i < in.size(); ++i)
        {
            if (!strchr(delims, in[i]))
                continue;
            string part(in.begin() + pos, in.begin() + i);
            trim_string(part);
            out.push_back(part);
            pos = i + 1;
        }
        {
            string part(in.begin() + pos, in.end());
            trim_string(part);
            out.push_back(part);
        }
    }

public:
    int load(const char *filename)
    {
        vector<char> content;
        if (read_file(filename, content))
            return 1;
        vector<string> lines;
        split_lines(content, lines);
        for (size_t i = 0; i < lines.size(); ++i)
        {
            if (lines[i].empty() || lines[i][0] == '#')
                continue;
            vector<string> parts;
            split_string(lines[i], parts, "=");
            if (parts.size() != 2 || parts[0].empty())
            {
                printf("[!] Invalid config: %s (%d)\n", filename, (int)i + 1);
                continue;
            }
            if (config.find(parts[0]) != config.end())
                printf("[!] Override config: %s (%d)\n", filename, (int)i + 1);
            config[parts[0]] = parts[1];
        }
        return 0;
    }

    int i(const char *key) const
    {
        map<string, string>::const_iterator iter = config.find(string(key));
        if (iter == config.end())
        {
            printf("[!] No config: %s\n", key);
            return 0;
        }
        int ret = 0;
        if (sscanf(iter->second.c_str(), "%d", &ret) != 1)
            printf("[!] Invalid int: %s\n", key);
        return ret;
    }

    const string & s(const char *key) const
    {
        map<string, string>::const_iterator iter = config.find(string(key));
        if (iter == config.end())
        {
            printf("[!] No config: %s\n", key);
            return empty;
        }
        return iter->second;
    }
} config;

// Read from out buffer and send to remote
class reader : public Runnable
{
    int index;
    StreamSocket socket;
    Event &readable, &writable;
    FIFOBuffer &buffer;
    bool &self_stopped, &peer_stopped;
public:
    reader(int i, StreamSocket s, Event &r, Event &w, FIFOBuffer &b, bool &self, bool &peer)
        : index(i), socket(s), readable(r), writable(w), buffer(b), self_stopped(self), peer_stopped(peer)
    {
    }

    virtual void run()
    {
        printf("[.] Reader %d, tid = %u\n", index, (unsigned int)Thread::currentTid());
        fflush(stdout);
        char buf[8192];
        bool closed = false;
        while (true)
        {
            if (buffer.isReadable())
            {
                int sz = (int)buffer.read(buf, 8192);
                writable.set();
                int sent = 0;
                while (sent < sz)
                {
                    int sz2;
                    try
                    {
                        sz2 = socket.sendBytes(buf + sent, sz - sent);
                    }
                    catch (Exception &exc)
                    {
                        printf("[!] Reader %d: %s\n", index, exc.displayText().c_str());
                        fflush(stdout);
                        sz2 = 0;
                    }
                    if (sz2 <= 0)
                    {
                        closed = true;
                        break;
                    }
                    sent += sz2;
                }
                if (closed)
                    break;
            }
            else
            {
                if (peer_stopped)
                    break;
                readable.wait();
            }
        }
        try { socket.shutdownSend(); } catch (...) {}
        self_stopped = true;
        writable.set();
    }
};

// Receive from remote and write to in buffer
class writer : public Runnable
{
    int index;
    StreamSocket socket;
    Event &readable, &writable;
    FIFOBuffer &buffer;
    bool &self_stopped, &peer_stopped;
public:
    writer(int i, StreamSocket s, Event &r, Event &w, FIFOBuffer &b, bool &self, bool &peer)
        : index(i), socket(s), readable(r), writable(w), buffer(b), self_stopped(self), peer_stopped(peer)
    {
    }

    virtual void run()
    {
        printf("[.] Writer %d, tid = %u\n", index, (unsigned int)Thread::currentTid());
        fflush(stdout);
        char buf[8192];
        while (true)
        {
            int sz;
            try
            {
                sz = socket.receiveBytes(buf, 8192);
            }
            catch (Exception &exc)
            {
                printf("[!] Writer %d: %s\n", index, exc.displayText().c_str());
                fflush(stdout);
                sz = 0;
            }
            if (sz <= 0)
                break;
            int written = 0;
            while (written < sz)
            {
                if (buffer.isWritable())
                {
                    written += buffer.write(buf + written, sz - written);
                    readable.set();
                }
                else
                {
                    if (peer_stopped)
                        break;
                    writable.wait();
                }
            }
            if (peer_stopped)
                break;
        }
        try { socket.shutdownReceive(); } catch (...) {}
        self_stopped = true;
        readable.set();
    }
};

// Receive from local and write to out buffers
class divider : public Runnable
{
    int nr_conn;
    StreamSocket socket;
    vector<Event *> &readables;
    Event &writable;
    vector<FIFOBuffer *> buffers;
    bool &self_stopped, *peers_stopped;
public:
    divider(int n, StreamSocket s, vector<Event *> &vr, Event &w, vector<FIFOBuffer *> &b, bool &self, bool *peers)
        : nr_conn(n), socket(s), readables(vr), writable(w), buffers(b), self_stopped(self), peers_stopped(peers)
    {
    }

    virtual void run()
    {
        printf("[.] Divider, tid = %u\n", (unsigned int)Thread::currentTid());
        fflush(stdout);
        int sz0 = 8192 * nr_conn;
        char *buf0 = new char[sz0];
        char **bufv = new char *[nr_conn];
        for (int i = 0; i < nr_conn; ++i)
            bufv[i] = new char[8192];
        int *szv = new int[nr_conn];
        int *writtenv = new int[nr_conn];
        long long total = 0;
        bool peer_stopped = false;
        while (true)
        {
            int sz;
            try
            {
                sz = socket.receiveBytes(buf0, sz0);
            }
            catch (Exception &exc)
            {
                printf("[!] Divider: %s\n", exc.displayText().c_str());
                fflush(stdout);
                sz = 0;
            }
            if (sz <= 0)
                break;
            for (int i = 0; i < nr_conn; ++i)
            {
                szv[i] = 0;
                writtenv[i] = 0;
            }
            for (int i = 0; i < sz; ++i)
                bufv[(total + i) % nr_conn][szv[(total + i) % nr_conn]++] = buf0[i];
            total += sz;
            while (true)
            {
                bool done = true;
                for (int i = 0; i < nr_conn; ++i)
                    if (writtenv[i] < szv[i])
                        done = false;
                if (done)
                    break;
                bool written = false;
                for (int i = 0; i < nr_conn; ++i)
                {
                    if (buffers[i]->isWritable())
                    {
                        writtenv[i] += buffers[i]->write(bufv[i] + writtenv[i], szv[i] - writtenv[i]);
                        readables[i]->set();
                        written = true;
                    }
                }
                if (!written)
                {
                    for (int i = 0; i < nr_conn; ++i)
                        if (peers_stopped[i])
                            peer_stopped = true;
                    if (peer_stopped)
                        break;
                    writable.wait();
                    writable.reset();
                }
                if (peer_stopped)
                    break;
            }
        }
        try { socket.shutdownReceive(); } catch (...) {}
        self_stopped = true;
        for (int i = 0; i < nr_conn; ++i)
            readables[i]->set();
    }
};

// Read from in buffers and send to local
class combiner : public Runnable
{
    int nr_conn;
    StreamSocket socket;
    Event &readable;
    vector<Event *> &writables;
    vector<FIFOBuffer *> buffers;
    bool &self_stopped, *peers_stopped;
public:
    combiner(int n, StreamSocket s, Event &r, vector<Event *> &vw, vector<FIFOBuffer *> &b, bool &self, bool *peers)
        : nr_conn(n), socket(s), readable(r), writables(vw), buffers(b), self_stopped(self), peers_stopped(peers)
    {
    }

    virtual void run()
    {
        printf("[.] Combiner, tid = %u\n", (unsigned int)Thread::currentTid());
        fflush(stdout);
        char *buf0 = new char[8192 * nr_conn];
        char **bufv = new char *[nr_conn];
        for (int i = 0; i < nr_conn; ++i)
            bufv[i] = new char[8192];
        int *szv = new int[nr_conn];
        int *readv = new int[nr_conn];
        long long total = 0;
        bool peer_stopped = false, closed = false;
        while (true)
        {
            for (int i = 0; i < nr_conn; ++i)
            {
                szv[i] = buffers[i]->peek(bufv[i], 8192);
                readv[i] = 0;
            }
            int sz = 0;
            while (true)
            {
                int i = (total + sz) % nr_conn;
                if (readv[i] >= szv[i])
                    break;
                buf0[sz++] = bufv[i][readv[i]++];
            }
            total += sz;
            if (sz > 0)
            {
                for (int i = 0; i < nr_conn; ++i)
                {
                    if (readv[i] > 0)
                    {
                        buffers[i]->read(bufv[i], readv[i]);
                        writables[i]->set();
                    }
                }
                int sent = 0;
                while (sent < sz)
                {
                    int sz1;
                    try
                    {
                        sz1 = socket.sendBytes(buf0 + sent, sz - sent);
                    }
                    catch (Exception &exc)
                    {
                        printf("[!] Combiner: %s\n", exc.displayText().c_str());
                        fflush(stdout);
                        sz1 = 0;
                    }
                    if (sz1 <= 0)
                    {
                        closed = true;
                        break;
                    }
                    sent += sz1;
                }
                if (closed)
                    break;
            }
            else
            {
                for (int i = 0; i < nr_conn; ++i)
                    if (peers_stopped[i])
                        peer_stopped = true;
                if (peer_stopped)
                    break;
                readable.wait();
                readable.reset();
            }
        }
        try { socket.shutdownSend(); } catch (...) {}
        self_stopped = true;
        for (int i = 0; i < nr_conn; ++i)
            writables[i]->set();
    }
};

int main(int argc, char **argv)
{
    config.load("multisocks.txt");
    for (int i = 1; i < argc; ++i)
        if (config.load(argv[i]))
            return 1;
    if (!config.s("log").empty())
        freopen(config.s("log").c_str(), "w", stdout);
    ErrorHandler::set(new error_handler);
    fflush(stdout);
    try
    {
        int nr_conn = config.i("nr_conn");
        bool divider_stopped = false, combiner_stopped = false;
        bool *readers_stopped = new bool[nr_conn](), *writers_stopped = new bool[nr_conn]();
        Event in_readable(false), out_writable(false);
        in_readable.reset();
        out_writable.set();
        vector<Event *> in_writables, out_readables;
        for (int i = 0; i < nr_conn; ++i)
        {
            in_writables.push_back(new Event);
            in_writables.back()->set();
            out_readables.push_back(new Event);
            out_readables.back()->reset();
        }
        vector<StreamSocket> remotes;
        if (config.s("remote") == "listen")
        {
            ServerSocket server;
            if (config.s("remote.protocol") == "ipv6")
                server.bind6(config.i("remote.port"), true, true);
            else
                server.bind(config.i("remote.port"), true);
            server.listen();
            for (int i = 0; i < nr_conn; ++i)
            {
                SocketAddress client_addr;
                remotes.push_back(server.acceptConnection(client_addr));
                remotes.back().setNoDelay(true);
                printf("[.] Remote incoming: %s\n", client_addr.toString().c_str());
                fflush(stdout);
                if (client_addr.host() != remotes.front().peerAddress().host())
                    throw Exception(string("client addresses mismatch"));
            }
            server.close();
        }
        else
        {
            SocketAddress server_addr(config.s("remote.host"), config.i("remote.port"));
            for (int i = 0; i < nr_conn; ++i)
            {
                StreamSocket remote;
                if (!config.s("remote.bind").empty())
                    remote.impl()->bind(SocketAddress(config.s("remote.bind"), 0), true);
                remote.connect(server_addr);
                remote.setNoDelay(true);
                remotes.push_back(remote);
                printf("[.] Connected to remote from: %s\n", remote.address().toString().c_str());
                fflush(stdout);
            }
        }
        vector<FIFOBuffer *> in_buffers, out_buffers;
        vector<reader *> readers;
        vector<writer *> writers;
        Thread *reader_threads = new Thread[nr_conn], *writer_threads = new Thread[nr_conn];
        for (int i = 0; i < nr_conn; ++i)
        {
            in_buffers.push_back(new FIFOBuffer(config.i("buffer_size")));
            out_buffers.push_back(new FIFOBuffer(config.i("buffer_size")));
            readers.push_back(new reader(i, remotes[i], *out_readables[i], out_writable, *out_buffers.back(), readers_stopped[i], divider_stopped));
            writers.push_back(new writer(i, remotes[i], in_readable, *in_writables[i], *in_buffers.back(), writers_stopped[i], combiner_stopped));
            reader_threads[i].setStackSize(65536);
            writer_threads[i].setStackSize(65536);
            reader_threads[i].start(*readers.back());
            writer_threads[i].start(*writers.back());
        }
        StreamSocket local;
        if (config.s("local") == "listen")
        {
            ServerSocket server;
            if (config.s("local.protocol") == "ipv6")
                server.bind6(config.i("local.port"), true, true);
            else
                server.bind(config.i("local.port"), true);
            server.listen();
            local = server.acceptConnection();
            printf("[.] Local incoming: %s\n", local.peerAddress().toString().c_str());
            fflush(stdout);
            local.setNoDelay(true);
            server.close();
        }
        else
        {
            SocketAddress server_addr(config.s("local.host"), config.i("local.port"));
            local.connect(server_addr);
            local.setNoDelay(true);
            printf("[.] Connected to local from: %s\n", local.address().toString().c_str());
            fflush(stdout);
        }
        divider *pdivider = new divider(nr_conn, local, out_readables, out_writable, out_buffers, divider_stopped, readers_stopped);
        combiner *pcombiner = new combiner(nr_conn, local, in_readable, in_writables, in_buffers, combiner_stopped, writers_stopped);
        Thread divider_thread, combiner_thread;
        divider_thread.setStackSize(65536);
        combiner_thread.setStackSize(65536);
        divider_thread.start(*pdivider);
        combiner_thread.start(*pcombiner);
        printf("[.] Connection established\n");
        fflush(stdout);
        divider_thread.join();
        combiner_thread.join();
        local.close();
        for (int i = 0; i < nr_conn; ++i)
        {
            reader_threads[i].join();
            writer_threads[i].join();
            remotes[i].close();
        }
        printf("[.] Connection closed\n");
    }
    catch (Exception &exc)
    {
        ErrorHandler::handle(exc);
    }
    catch (exception &exc)
    {
        ErrorHandler::handle(exc);
    }
    catch (...)
    {
        ErrorHandler::handle();
    }
    return 0;
}      

測試

用戶端網絡環境為 CERNET ,伺服器是位于達拉斯的 BurstNET VPS ,連接配接方式是 TCP/IPv6 。

直接連接配接, OpenVPN 速率為 341 KB/s 。使用 5 個連接配接進行分流, OpenVPN 速率可達 1.4 MB/s 。

繼續閱讀