自定义线程池

自定义线程池

自定义线程池_第1张图片

注意: 需要c++17c++20 的支持

代码实现:

#ifndef THREADPOOL_H
#define THREADPOOL_H

#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 

//线程池,单例类
class ThreadPool
{
public:
    static ThreadPool& getInstance()
    {
        //std::thread::hardware_concurrency() //获取线程数  最小10个线程
        static int size = std::max((int)std::thread::hardware_concurrency(),10);
        static ThreadPool instance(size,size * 2);
        return instance;
    }

    ThreadPool(ThreadPool const&) = delete;
    void operator=(ThreadPool const&) = delete;

    //增加运行的函数 动态获取类型
    template<typename F, typename... Args>
    auto enqueue(F&& f, Args&&... args) -> std::future<typename std::invoke_result<F,Args...>::type>
    {
        //返回类型
        using return_type = typename std::invoke_result<F,Args...>::type;
        //打包函数
        auto task = std::make_shared<std::packaged_task<return_type()> >(
            std::bind(std::forward<F>(f), std::forward<Args>(args)...)
            );

        std::future<return_type> res = task->get_future();
        {
            std::lock_guard <std::mutex> lock(m_queue_mutex);

            if (m_stop)
            {
                //return res;
                throw std::runtime_error("enqueue on stopped ThreadPool");
            }

            // 添加线程来运行任务
            // 检查是否有空闲线程,否则创建新线程
            bool assigned = true;
            for(const auto& it : m_workers)
            {
                if(it->status ==  Status::Idle)
                {
                    //标记有空闲线程
                    assigned = false;
                    break;
                }
            }
            //创建新线程
            if (assigned)
            {
                add_thread();
            }
            //添加任务
            m_tasks.emplace([task]() { (*task)(); });
        }
        //随机唤醒一个线程
        m_condition.notify_one();

        return res;
    }

    ~ThreadPool()
    {
        stop_all();
    }

private:
    ThreadPool(size_t min_threads,size_t max_thread,size_t timeout_ms = 60000)
        :m_max_threads(max_thread),m_timeout_ms(timeout_ms), m_stop(false)
    {
        //创建
        for (size_t i = 0; i < min_threads; ++i)
        {
            add_thread();
        }
    }

    //添加线程
    void add_thread()
    {
        //创建工作线程
        std::shared_ptr<Worker> worker = std::make_shared<Worker>();
        //默认状态 闲置
        worker->status = Status::Idle;
        worker->thread = std::thread([this,worker]() {
            for(;;)
            {
                std::function<void()> task;
                {
                    // std::unique_lock lock(this->m_queue_mutex);
                    // this->m_condition.wait(lock, [this] {
                    //     return this->m_stop || !this->m_tasks.empty();
                    // });

                    //当前线程超时处理(相对时间)  超时返回false
                    std::unique_lock<std::mutex> lock(this->m_queue_mutex);
                    if (!this->m_condition.wait_for(lock, std::chrono::milliseconds(this->m_timeout_ms), [this] {
                            return this->m_stop || !this->m_tasks.empty();
                        }))
                    {
                        // 超时处理 大于最大线程数,标记移除  否则,跳过
                        if(this->m_workers.size() > this->m_max_threads)
                        {
                            auto it = this->m_workers.begin();
                            while (it != this->m_workers.end())
                            {
                                // 使用循环来移除所有状态为 Status::Removed 的 Worker
                                if ((*it)->status == Status::Removed)
                                {
                                    it = m_workers.erase(it);  // erase 返回下一个有效迭代器
                                }
                                else
                                {
                                    ++it;
                                }
                            }

                            worker->status = Status::Removed;
                            return;
                        }
                        else
                        {
                            continue;
                        }
                    }

                    //设置线程运行状态 Running
                    worker->status = Status::Running;
                    //退出条件
                    if (this->m_stop && this->m_tasks.empty())
                    {
                        return;
                    }
                    //获取任务
                    task = std::move(this->m_tasks.front());
                    this->m_tasks.pop();

                }
                //运行函数
                task();

                //函数运行完后 设置 闲置状态
                worker->status = Status::Idle;
            }
        });
        m_workers.push_back(std::move(worker));
    }

    //停止所有线程
    void stop_all()
    {
        {
            std::lock_guard <std::mutex> lock(m_queue_mutex);
            m_stop = true;
        }
        m_condition.notify_all();
        for(auto& worker: m_workers)
        {
            if(worker->thread.joinable())
            {
                worker->thread.join();
            }
        }
        m_workers.clear();
    }

    enum class Status
    {
        Running, //运行
        Idle,   //闲置
        Removed //移除
    };
    //工作线程
    struct Worker
    {
        Status status;
        std::thread thread;
    };

    std::queue<std::function<void()>> m_tasks;//任务队列
    std::vector<std::shared_ptr<Worker>> m_workers; //工作线程列表
    std::mutex m_queue_mutex;
    std::condition_variable m_condition;
    std::atomic<bool>  m_stop;
    size_t m_max_threads;  //最大线程数,闲置时 当前线程列表超出最大线程数时,标记 Removed  弹性容量
    size_t m_timeout_ms; //超时 就把当前线程标记 Removed,下一轮添加任务把这个线程移除
};


#endif // THREADPOOL_H




//普通函数
int sub(int a,int b)
{
    std::cout<<"sub: "<< a<<" - "<< b<<std::endl;

    int i=0;
    while((i++) != 200)
    {
        std::this_thread::sleep_for(std::chrono::microseconds(1));
        //耗时运算
    }

    return a - b;
}


//类函数
class Add
{
public:

    int add(int a,int b)
    {
        std::cout<<"Add::add: "<< a<<" + "<< b<<std::endl;

        int i=0;
        while((i++) != 200)
        {
            std::this_thread::sleep_for(std::chrono::microseconds(1));
            //耗时运算
        }
        return a+b;
    }
};

int main(int argc, char *argv[])
{
    int a = 60;
    int b = 20;

    int c = 30;
    int d = 40;

    //线程池运算 支持获取结果

    //######################################### 普通函数使用 #########################################//
    //方法1 - 直接调用
    std::future<int> res = ThreadPool::getInstance().enqueue(sub,a,b);
    //方法2 - lambda调用
    std::future<int> res1 = ThreadPool::getInstance().enqueue([a,b](){
        return sub(a,b);
    });
    //方法3 - std::bind绑定调用
    //绑定
    auto func = std::bind(sub,a,b);
    std::future<int> res2 = ThreadPool::getInstance().enqueue(func);

    std::cout<<"res:"<<res.get()<<std::endl;//先打印 res 然后等待线程函数结果
    std::cout<<"res1:"<<res1.get()<<std::endl;
    std::cout<<"res2:"<<res2.get()<<std::endl;



    //######################################### 类函数使用 #########################################//
    Add ddd;
    //方法1 - 直接调用
    std::future<int> res8 = ThreadPool::getInstance().enqueue(&Add::add,ddd,c,d);
    //方法2 - lambda调用
    std::future<int> res9 = ThreadPool::getInstance().enqueue([&ddd,c,d](){
        return ddd.add(c,d);
    });
    //方法3 - std::bind绑定调用
    //绑定
    auto func2 = std::bind(&Add::add,&ddd,c,d);
    std::future<int> res10 = ThreadPool::getInstance().enqueue(func2);

    std::cout<<"res8:"<<res8.get()<<std::endl;//先打印 res8 然后等待线程函数结果
    std::cout<<"res9:"<<res9.get()<<std::endl;

    std::cout<<"res10:"<<res10.get()<<std::endl;


    return 1;
}

运算结果

res:sub: 60 - 20
sub: 60sub:  - 2060
 - 20
40
res1:40
res2:40
Add::add: res8:30Add::add:  + 4030
Add::add: 30 +  + 4040

70
res9:70
res10:70

你可能感兴趣的:(开发语言,c++)