C++开发工具库

轮子?

Posted by 大狗 on July 21, 2021

C++开发工具库

1 并发多线程map

坦白讲核心思想都是一样的:多个bucket,每个bucket下面挂一个红黑树(实际上就是STL的map),红黑树使用boost的读写锁来保护。源代码里面我禁用了拷贝构造和赋值运算符,因为锁(RAII锁对象以及底层的mutex)不可拷贝是一个常识;如果想拷贝,那么就删掉禁用拷贝的两行声明,并打开`#if 0`宏包住的那段拷贝实现,但还是尽量建议使用shared_ptr来共享同一个实例。

#ifndef __CONCURRENT_MAP_H__
#define __CONCURRENT_MAP_H__

#include <atomic>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <utility>
#include <vector>

#include <boost/thread/locks.hpp>
#include <boost/thread/shared_mutex.hpp>



/* 使用此代码时请注意,直接用.h文件,不要写.cpp,使用置入式模型*/
using std::map;
using namespace boost;

/* The bucket code is simple; the one thing to watch is the difference between Insert and
 * Update: Update requires the key to already exist, Insert adds it when it is missing. */
/* Every operation locks only this one bucket and never holds two bucket locks at once,
 * so lock ordering cannot cause a deadlock. */
/**
 * One shard of ConcurrentMap: a std::map<KEY_T, VALUE_T> guarded by a boost::shared_mutex.
 *
 * Read-only operations take a shared (reader) lock; mutating operations take an exclusive
 * (writer) lock. The *WithoutLock helpers assume the caller already holds the writer lock.
 *
 * @tparam Compare  default-constructible functor. When a key already exists,
 *                  Compare()(new_value, current_value) must return true for Insert/Update
 *                  to overwrite the stored value.
 */
template<class KEY_T, class VALUE_T, class Compare>
class ConcurrentBucket {
    private:
        typedef std::map<KEY_T, VALUE_T> map_type;
        typedef boost::shared_lock<boost::shared_mutex> read_lock;
        typedef boost::unique_lock<boost::shared_mutex> write_lock;

    public:
        ConcurrentBucket() {
        }

        virtual ~ConcurrentBucket() {
            write_lock lock(rwlock_);
            map_.clear();
        }

        /* Copies the value stored for key into value and returns true.
         * Returns false, leaving value untouched, when key is absent. */
        bool Lookup(const KEY_T &key, VALUE_T &value) const {
            /* Keep the lock's lifetime as short as possible. */
            read_lock lock(rwlock_);
            typename map_type::const_iterator found = map_.find(key);
            if (found == map_.end()) {
                return false;
            }
            /* Copy-assigns the stored value into the caller's object. */
            value = found->second;
            return true;
        }

        /* Number of entries currently stored in this bucket. */
        uint64_t Size() const {
            read_lock lock(rwlock_);
            return map_.size();
        }

        void Clear() {
            write_lock lock(rwlock_);
            map_.clear();
        }

        bool Contain(const KEY_T &key) const {
            read_lock lock(rwlock_);
            return map_.find(key) != map_.end();
        }

        /* Adds key -> value when key is missing. When key exists, overwrites the value only
         * if Compare()(value, current) is true. Returns whether the bucket changed. */
        bool Insert(const KEY_T &key, const VALUE_T &value) {
            write_lock lock(rwlock_);
            return InsertWithoutLock(key, value);
        }

        /* Like Insert, but never adds a missing key: returns false when key is absent. */
        bool Update(const KEY_T &key, const VALUE_T &value) {
            write_lock lock(rwlock_);
            return UpdateWithoutLock(key, value);
        }

        /* Removes key if present; a missing key is not an error. */
        void Remove(const KEY_T &key) {
            write_lock lock(rwlock_);
            RemoveWithoutLock(key);
        }

        /* Appends a copy of every (key, value) pair to list; existing contents are kept. */
        void GetAllKey(std::vector<std::pair<KEY_T, VALUE_T> > &list) const {
            read_lock lock(rwlock_);
            for (typename map_type::const_iterator it = map_.begin(); it != map_.end(); ++it) {
                list.push_back(std::pair<KEY_T, VALUE_T>(it->first, it->second));
            }
        }

        /* Applies Update to every pair while holding the writer lock once. */
        void UpdateKeyBatch(const std::vector<std::pair<KEY_T, VALUE_T> > &list) {
            write_lock lock(rwlock_);
            for (const std::pair<KEY_T, VALUE_T> &entry : list) {
                UpdateWithoutLock(entry.first, entry.second);
            }
        }

        /* Applies Insert to every pair while holding the writer lock once. */
        void InsertKeyBatch(const std::vector<std::pair<KEY_T, VALUE_T> > &list) {
            write_lock lock(rwlock_);
            for (const std::pair<KEY_T, VALUE_T> &entry : list) {
                InsertWithoutLock(entry.first, entry.second);
            }
        }

        /* Removes every listed key while holding the writer lock once. */
        void RemoveKeyBatch(const std::vector<KEY_T> &list) {
            write_lock lock(rwlock_);
            for (typename std::vector<KEY_T>::const_iterator it = list.begin(); it != list.end(); ++it) {
                RemoveWithoutLock(*it);
            }
        }

        void Echo() const {
            std::cout << " oh ho" << std::endl;
        }

    private:
        /* Caller must hold the writer lock. */
        void RemoveWithoutLock(const KEY_T &key) {
            map_.erase(key);
        }

        /* Caller must hold the writer lock. */
        bool InsertWithoutLock(const KEY_T &key, const VALUE_T &value) {
            typename map_type::iterator found = map_.find(key);
            if (found == map_.end()) {
                map_.insert(typename map_type::value_type(key, value));
                return true;
            }
            return ReplaceIfPreferred(found, value);
        }

        /* Caller must hold the writer lock. Never inserts a missing key. */
        bool UpdateWithoutLock(const KEY_T &key, const VALUE_T &value) {
            typename map_type::iterator found = map_.find(key);
            if (found == map_.end()) {
                return false;
            }
            return ReplaceIfPreferred(found, value);
        }

        /* Overwrites found->second with value when Compare()(value, current) holds.
         * Assigning in place avoids the erase + re-insert pair (a node free plus a new
         * allocation), and the key can no longer be lost if that insert throws. */
        bool ReplaceIfPreferred(typename map_type::iterator found, const VALUE_T &value) {
            if (!Compare()(value, found->second)) {
                return false;
            }
            found->second = value;
            return true;
        }

        map_type map_;
        /* boost shared_mutex (written for boost 1.69). It is mutable so that const readers
         * can take the shared lock. */
        mutable boost::shared_mutex rwlock_;
};

template<typename KEY_T, typename VALUE_T, typename Compare, typename Hash=std::hash<KEY_T> >
class ConcurrentMap {
    public:
        ConcurrentMap(uint64_t bucket_nums = 61, Hash const& hasher = Hash()) :
            buckets_(bucket_nums),
            hasher_ (hasher) {
            for (uint64_t i = 0; i < bucket_nums; ++i) {
                buckets_[i].reset(new ConcurrentBucket<KEY_T, VALUE_T, Compare>);
            }
        }

        bool LookUp(const KEY_T &key, VALUE_T &value) {
            return GetBucket(key).Lookup(key, value);
        }

        bool Contain(const KEY_T &key) {
            return GetBucket(key).Contain(key);
        }

        bool Insert(const KEY_T &key, const VALUE_T &value) {
            return GetBucket(key).Insert(key, value);
        }


        bool Update(const KEY_T &key, const VALUE_T &value)  {
            return GetBucket(key).Update(key, value);
        }

        bool Delete(const KEY_T &key) {
            GetBucket(key).Remove(key);
            return true;
        }

        void GetAllKey(std::vector<std::pair<KEY_T, VALUE_T> > &list)  {
            list.clear();
#if 0

            //auto iter = buckets_.begin();
            //这里要注意,先得对iter做解引用,得到uniqueptr,然后对unique_ptr做解引用或者用->才能进行调用
            typename std::vector<std::unique_ptr<ConcurrentBucket<KEY_T, VALUE_T, Compare> > >::iterator iter = buckets_.begin();
            for ( ; iter != buckets_.end(); ++iter ) {
                (*iter)->GetAllKey(list);
            }
#endif

            for (size_t bucket_index = 0; bucket_index < buckets_.size(); ++bucket_index) {
                buckets_[bucket_index]->GetAllKey(list);
            }

            return;
        }

        
        void UpdateKeyBatch(std::vector<std::pair<KEY_T, VALUE_T> > &list)  {
            /* 更新数组的个数需要和桶的个数一致 */
            std::vector<std::vector<std::pair<KEY_T, VALUE_T> > > update_lists(buckets_.size());

            /* 将更新的元素丢入数组之中 */
            typename std::vector<std::pair<KEY_T, VALUE_T> >::iterator iter = list.begin();
            for (; iter != list.end(); ++iter) {
                std::size_t  bucket_index = hasher_(iter->first) % buckets_.size();
                update_lists[bucket_index].push_back(std::pair<KEY_T, VALUE_T>(iter->first, iter->second));
            }

            /* 将对应的元素更新到对应的bucket里面 */
            for (size_t bucket_index = 0; bucket_index < buckets_.size(); ++bucket_index) {
                buckets_[bucket_index]->UpdateKeyBatch(update_lists[bucket_index]);
            }
        }

        void RemoveKeyBatch(std::vector<KEY_T> &list)  {
            /* 删除数组的个数需要和桶的个数一致 */
            std::vector<std::vector<KEY_T> > remove_lists(buckets_.size());

            /* 将删除的元素丢入数组之中 */
            typename std::vector<KEY_T>::iterator iter = list.begin();
            for (; iter != list.end(); ++iter) {
                std::size_t  bucket_index = hasher_(*iter) % buckets_.size();
                remove_lists[bucket_index].push_back(*iter);
            }

            /* 将对应的元素更新到对应的bucket里面 */
            for (size_t bucket_index = 0; bucket_index < buckets_.size(); ++bucket_index) {
                buckets_[bucket_index]->RemoveKeyBatch(remove_lists[bucket_index]);
            }
        }

        
        void InsertKeyBatch(std::vector<std::pair<KEY_T, VALUE_T> > &list)  {
            /* 更新数组的个数需要和桶的个数一致 */
            std::vector<std::vector<std::pair<KEY_T, VALUE_T> > > Insert_lists(buckets_.size());

            /* 将更新的元素丢入数组之中 */
            typename std::vector<std::pair<KEY_T, VALUE_T> >::iterator iter = list.begin();
            for (; iter != list.end(); ++iter) {
                std::size_t  bucket_index = hasher_(iter->first) % buckets_.size();
                Insert_lists[bucket_index].push_back(std::pair<KEY_T, VALUE_T>(iter->first, iter->second));
            }

            /* 将对应的元素更新到对应的bucket里面 */
            for (size_t bucket_index = 0; bucket_index < buckets_.size(); ++bucket_index) {
                buckets_[bucket_index]->InsertKeyBatch(Insert_lists[bucket_index]);
            }
        }


        ConcurrentBucket<KEY_T, VALUE_T, Compare>& GetBucket(KEY_T const& key) const {
            std::size_t const bucket_index = hasher_(key)% buckets_.size();
            return *buckets_[bucket_index];
        }

        uint64_t Size() {
            uint64_t total_size = 0;
            for (size_t i = 0 ; i < buckets_.size(); ++i) {
                total_size += buckets_[i]->Size();
            }
            return total_size;
        }

        void Clear() {
            for (size_t i = 0 ; i < buckets_.size(); ++i) {
                buckets_[i]->Clear();
            } 
        }

        /* 禁用这两者来保证绝对的安全,主要是本身mutex就是禁止拷贝的*/
        /* 禁用拷贝构造 */
        ConcurrentMap(ConcurrentMap const &other) = delete;

        /* 禁用赋值运算符 */
        ConcurrentMap& operator=(ConcurrentMap const &other) = delete;

#if 0

        /* 线程安全的拷贝构造函数 */
        ConcurrentMap(ConcurrentMap const &other) :
            buckets_(other.buckets_.size()), 
            hasher_(other.hasher_) {

            for (uint64_t i = 0; i < buckets_.size(); ++i) {
                buckets_[i].reset(new ConcurrentBucket<KEY_T, VALUE_T, Compare>);
            }

            std::vector<std::pair<KEY_T, VALUE_T> > list;
            other.GetAllKey(list);

            InsertKeyBatch(list);
        }

        ConcurrentMap& operator=(ConcurrentMap const &other) {
            /* 务必保证两个bucket的大小一致,不提供可伸缩的bucket*/
            assert(buckets_.size() == other.buckets_.size());

            hasher_ = other.hasher_;

            for (uint64_t i = 0; i < buckets_.size(); ++i) {
                buckets_[i].reset(new ConcurrentBucket<KEY_T, VALUE_T, Compare>);
            }

            std::vector<std::pair<KEY_T, VALUE_T> > list;
            other.GetAllKey(list);
            InsertKeyBatch(list);
        }
#endif

    private:
        /* 使用unique_ptr保证内存安全 */
        std::vector<std::unique_ptr<ConcurrentBucket<KEY_T, VALUE_T, Compare> > > buckets_;
        Hash hasher_;
};


#endif

2 并发LRU实现

实际上并发LRU的实现和并发MAP是非常相似的:并发MAP挂了一堆BUCKET,而并发LRU实际上也是挂了一堆LRU bucket,每个key通过哈希找到自己对应的那个LRU。但严格来说这并不是完全的LRU,因为它拆分出了几个互相独立的链表:淘汰只在单个bucket内部按最近使用顺序进行,容量也是按bucket计算的。我目前实际使用的还是最简单的单个LRU,性能比较低,但语义是严格的。

#ifndef __LRU_H__
#define __LRU_H__


#include <boost/thread/locks.hpp>
#include <boost/thread/shared_mutex.hpp>

#include <map>
#include <list>
#include <vector>
#include <string>
#include <unordered_map>


/**
 * Capacity-bounded least-recently-used cache guarded by a boost::shared_mutex.
 *
 * list_ holds keys from most recently used (front) to least recently used (back). map_ maps
 * each key to its value plus its node in list_. Put and both Get overloads reorder list_, so
 * they take the exclusive lock; GetAllKey only reads and takes the shared lock.
 *
 * When full, Put evicts the least recently used key. A capacity of 0 caches nothing.
 */
template<typename KEY_T, typename VALUE_T>
class LRU {
    private:
        typedef boost::shared_lock<boost::shared_mutex> read_lock;
        typedef boost::unique_lock<boost::shared_mutex> write_lock;
        typedef std::list<KEY_T> order_list;
        /* The value plus the key's node in list_. "typename" is required on the dependent iterator. */
        typedef std::pair<VALUE_T, typename order_list::iterator> entry_type;
        typedef std::map<KEY_T, entry_type> entry_map;

    public:
        LRU(uint64_t capacity = 31) :
            capacity_(capacity) {
        }

        /* Inserts or overwrites key and marks it most recently used. When the cache is full
         * and key is new, the least recently used entry is evicted first. */
        void Put(const KEY_T &key, const VALUE_T &value) {
            write_lock lock(rwlock_);
            typename entry_map::iterator found = map_.find(key);
            if (found != map_.end()) {
                touch(found);
                found->second.first = value;
                return;
            }
            if (capacity_ == 0) {
                /* Nothing can be stored; evicting from an empty list_ would be UB. */
                return;
            }
            if (map_.size() >= capacity_) {
                map_.erase(list_.back());
                list_.pop_back();
            }
            list_.push_front(key);
            try {
                map_.insert(typename entry_map::value_type(key, entry_type(value, list_.begin())));
            } catch (...) {
                /* Keep list_ and map_ consistent if the insert fails. */
                list_.pop_front();
                throw;
            }
        }

        /* On a hit, copies the value into value, marks key most recently used and returns true.
         * On a miss, returns false and leaves value untouched. */
        bool Get(const KEY_T &key, VALUE_T &value) {
            write_lock lock(rwlock_);
            typename entry_map::iterator found = map_.find(key);
            if (found == map_.end()) {
                return false;
            }
            touch(found);
            value = found->second.first;
            return true;
        }

        /* On a miss, returns a default-constructed VALUE_T; nothing is cached in that case. */
        /* On a hit, moves the key to the front (most recently used). */
        VALUE_T Get(const KEY_T &key) {
            VALUE_T value = VALUE_T();
            Get(key, value);
            return value;
        }

        /* Appends a copy of every (key, value) pair to res (in key order); always returns true. */
        bool GetAllKey(std::vector<std::pair<KEY_T, VALUE_T> > &res) const {
            read_lock lock(rwlock_);
            for (typename entry_map::const_iterator it = map_.begin(); it != map_.end(); ++it) {
                res.push_back(std::pair<KEY_T, VALUE_T>(it->first, it->second.first));
            }
            return true;
        }

    private:
        /* Moves iter's key to the front of list_. splice relinks the existing node in O(1)
         * without allocating, and the stored list iterator stays valid. */
        void touch(typename entry_map::iterator iter) {
            list_.splice(list_.begin(), list_, iter->second.second);
        }

        uint64_t capacity_;

        /* mutable so that the const GetAllKey can take the shared lock. */
        mutable boost::shared_mutex rwlock_;

        order_list list_;
        entry_map map_;
};


template<typename KEY_T, typename VALUE_T, typename Hash=std::hash<KEY_T> >
class ConcurrentLRU {
    public:
        ConcurrentLRU(uint64_t bucket_nums = 11, Hash const& hasher = Hash()) :
            buckets_(bucket_nums),
            hasher_ (hasher) {
            for (uint64_t i = 0; i < bucket_nums; ++i) {
                buckets_[i].reset(new LRU<KEY_T, VALUE_T>);
            }
        }

        void Put(const KEY_T &key, const VALUE_T &value) {
            GetBucket(key).Put(key, value);
            return;
        }
        bool Get(const KEY_T &key, VALUE_T &value) {
            return GetBucket(key).Get(key, value);
        }
        /* 没有的话返回一个新的VALUE_T回去,此时缓存是没有的 */
        /* 如果有的话放到最开始的地方 */
        VALUE_T Get(const KEY_T &key) {
            return GetBucket(key).Get(key);
        }        

        void GetAllKey(std::vector<std::pair<KEY_T, VALUE_T> > &list)  {
            list.clear();
#if 0

            //auto iter = buckets_.begin();
            //这里要注意,先得对iter做解引用,得到uniqueptr,然后对unique_ptr做解引用或者用->才能进行调用
            typename std::vector<std::unique_ptr<LRU<KEY_T, VALUE_T> > >::iterator iter = buckets_.begin();
            for ( ; iter != buckets_.end(); ++iter ) {
                (*iter)->GetAllKey(list);
            }
#endif

            for (size_t bucket_index = 0; bucket_index < buckets_.size(); ++bucket_index) {
                buckets_[bucket_index]->GetAllKey(list);
            }

            return;
        }

        LRU<KEY_T, VALUE_T>& GetBucket(KEY_T const& key) const {
            std::size_t const bucket_index = hasher_(key)% buckets_.size();
            return *buckets_[bucket_index];
        }

        /* 禁用拷贝构造 */
        ConcurrentLRU(ConcurrentLRU const &other) = delete;

        /* 禁用赋值运算符 */
        ConcurrentLRU& operator=(ConcurrentLRU const &other) = delete;

    private:
        /* 使用unique_ptr保证内存安全 */
        std::vector<std::unique_ptr<LRU<KEY_T, VALUE_T> > > buckets_;
        Hash hasher_;
};


#endif

3 一个线程安全的连接池

下面这个连接池在写的时候由于需要证书更新,因此需要使用TLSSocketWrapper来包裹证书/密钥/ca证书。这里要注意为了实现每次握手的时候只使用新的证书,所以这里会加一个证书的时间戳,来保证用的都是新的证书。可以发现这样子即使证书握手失败也依然会拿新的证书来执行握手。

该连接池首先执行TLS握手,然后把握手成功的连接(包裹了ssl::stream<ip::tcp::socket>的boost::shared_ptr<TLSSocketWrapper>)放到LRU中,以保证LRU里缓存的总是最新的连接;发送数据失败时会重新建立连接。为了保活,设计上每15秒对缓存的连接发一次HTTP GET来探活(注意:下面HttpHeartBeat中实际发送探活请求的代码目前是被注释掉的)。

头文件:

#ifndef __SYNC_TLS_CLIENT_H__
#define __SYNC_TLS_CLIENT_H__

#include <iostream>
#include <string>

#include "lru.h"
#include "singleton.h"
#include "repeated_timer.h"

#include <string>
#include <istream>
#include <ostream>
#include <cstdlib>
#include <string>
#include <atomic>

#include <boost/asio.hpp>
#include <boost/asio/ssl.hpp>



#include "boost/thread.hpp"
#include "boost/function.hpp"
#include "boost/shared_ptr.hpp"
#include "boost/move/unique_ptr.hpp"
#include "boost/make_shared.hpp"
#include "boost/lockfree/spsc_queue.hpp"


#include <boost/beast/core.hpp>
#include <boost/beast/http.hpp>
#include <boost/beast/version.hpp>
#include <boost/asio/connect.hpp>
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio/ssl/error.hpp>
#include <boost/asio/ssl/stream.hpp>

using tcp = boost::asio::ip::tcp;       // from <boost/asio/ip/tcp.hpp>
namespace ssl = boost::asio::ssl;       // from <boost/asio/ssl.hpp>
namespace http = boost::beast::http;    // from <boost/beast/http.hpp>


using namespace boost::asio;

/* HTTP version passed to beast request constructors: 11 means HTTP/1.1. */
#define HTTP_11_VERSION 11
/* Value sent in the Keep-Alive header of every request. */
#define DEFAULT_HTTP_ALIVE_CONFIG "timeout=60, max=1000"
/* Content-Type strings, presumably for the request_content_type argument of PostMessage. */
#define CONTENT_TYPE_JSON "application/json; charset=utf-8" 
#define CONTENT_TYPE_PLAIN "text/plain; charset=utf-8"
#define CONTENT_TYPE_HTML "text/html; charset=utf-8"
/* Maximum connect + TLS handshake attempts made by the TLSSocketWrapper constructor. */
#define CONNECT_MAX_TRY_TIMES 4


/*
 * One TLS client connection (ssl::stream over a TCP socket), together with the
 * certificate, key, CA certificate and certificate timestamp it was created with.
 */
class TLSSocketWrapper {
    public:
        /* Connecting and the TLS handshake both happen inside the constructor.
         * Check HandshakeFinished() afterwards to see whether they succeeded. */
        TLSSocketWrapper(const std::string &ip, const std::string &port, boost::asio::io_context &global_io_context, const std::string &cert, const std::string &key, const std::string &ca_cert, uint64_t time_stamp);

        std::string GetCert() const;
        std::string GetKey() const;
        std::string GetCaCert() const;
        /* Timestamp of the certificate material. Callers compare it with <= / >= to decide
         * which certificates are newer. */
        uint64_t GetTimeStamp() const;
        /* True when the constructor ended with a connected, handshaken socket. */
        bool HandshakeFinished() const;

        /* Each sends one HTTP request and reads one response, returning the HTTP status
         * (500 when an exception is caught). The original note said these two are not
         * thread-safe. NOTE(review): the implementations do hold rwlock_ exclusively around
         * each write/read pair, so the socket I/O itself is serialized; confirm which
         * guarantee callers should rely on. */
        int GetMessageCore(const std::string &server_path, std::string &header, std::string &body);
        int PostMessageCore(const std::string &server_path, const std::string &request_content_type, const std::string &request_body, std::string &header, std::string &body);

    


    private:
        /* NOTE(review): read_lock is not used by any visible member function. */
        typedef boost::shared_lock<boost::shared_mutex> read_lock;
        typedef boost::unique_lock<boost::shared_mutex> write_lock;

        /* Null when connect/handshake failed in the constructor. */
        boost::unique_ptr<ssl::stream<ip::tcp::socket> > real_socket_;
        /* Serializes request/response round trips on real_socket_. */
        boost::shared_mutex rwlock_;

        /* IP address or host name, and port, as passed to the constructor. */
        std::string ip_;
        std::string port_;
        std::string cert_;
        std::string key_;
        std::string ca_cert_;
        uint64_t time_stamp_;
};


/*
 * Replacement predicate over cached connections: returns true when a's certificate
 * timestamp is at least as new as b's.
 *
 * It has the shape of ConcurrentBucket's Compare parameter, which calls
 * Compare()(new_value, old_value) and replaces old_value only when the result is true.
 * Because >= is reflexive it is NOT a strict weak ordering, so never use it as a
 * std::map/std::set/std::sort comparator. Both pointers must be non-null.
 */
class TlsWrapperCompare {
    public:
        bool operator() (const boost::shared_ptr<TLSSocketWrapper> &a, const boost::shared_ptr<TLSSocketWrapper> &b) const {
            return a->GetTimeStamp() >= b->GetTimeStamp();
        }
};

/*
 * Process-wide manager (a singleton via hxn::Singleton). It owns the shared TLS connection
 * cache, the io_context handed to new TLS connections, and the heartbeat timer thread.
 */
class SyncTlsConnectionManager : public hxn::Singleton<SyncTlsConnectionManager> {

    public:
        /* Currently this instance maintains the cached connections and keeps them alive. */
        /* Init() creates io_service_, starts loop_ running it, and arms timer_ to call
         * HttpHeartBeat() every 15 seconds. */
        void Init();
        /* Walks every cached connection (keys have the form "ip#port"). */
        void HttpHeartBeat();
        io_context& GetGlobalIOContext();
        /* NOTE(review): the third template argument of ConcurrentLRU is its Hash parameter
         * (template<KEY_T, VALUE_T, Hash>), but TlsWrapperCompare is a two-argument
         * predicate over shared_ptr<TLSSocketWrapper>, not a hash of std::string.
         * ConcurrentLRU::GetBucket() calls hasher_(key) with a std::string, so Get/Put on
         * this cache should fail to compile. It presumably should be
         * ConcurrentLRU<std::string, boost::shared_ptr<TLSSocketWrapper> >. Fix it together
         * with the out-of-class definition of GetGlobalSocketWrapper(). */
        ConcurrentLRU<std::string, boost::shared_ptr<TLSSocketWrapper>, TlsWrapperCompare>& GetGlobalSocketWrapper();

    private:
        /* The first io_context: passed to every TLSSocketWrapper for its synchronous
         * connect and handshake. */
        boost::asio::io_context global_io_context_;
        /* Cached connections keyed by "ip#port". */
        ConcurrentLRU<std::string, boost::shared_ptr<TLSSocketWrapper>, TlsWrapperCompare> global_sockets_;


        /* Asynchronous timer + thread, built on boost::asio (+ chrono). */
        boost::shared_ptr<boost::thread> loop_;
        boost::shared_ptr<RepeatedTimer> timer_;
        /* This io_service is used only by the global heartbeat timer. */
        boost::shared_ptr<io_service> io_service_;

};



/* NOTE(review): the original comment claimed synchronous reads/writes on an asio socket are
 * thread-safe. Asio documents shared ssl::stream objects as unsafe for concurrent use, which
 * is why TLSSocketWrapper serializes each write/read pair under its own lock. */
/*
 * Request client bound to one ip/port. It reuses the cached TLS connection for "ip#port"
 * when that connection's certificate timestamp is new enough. When a request fails with
 * 500 it reconnects once and retries.
 */
class SyncTlsClient {
    public:
        /* Constructor for the heartbeat/keep-alive path: wraps an existing connection. */
        SyncTlsClient(const std::string &ip, const std::string &port, const boost::shared_ptr<TLSSocketWrapper> &connection);
        /* time_stamp is the certificate timestamp (i.e. the timestamp of the certificate files). */
        SyncTlsClient(const std::string &ip, const std::string &port, const std::string &cert, const std::string &key, const std::string &ca_cert, const uint64_t &time_stamp);

        /* An HTTP heartbeat is just a GET request. */
        int HeartBeat(const std::string &server_path, std::string &response_header, std::string &response_body);
        int GetMessage(const std::string &server_path, std::string &response_header, std::string &response_body);
        int PostMessage(const std::string &server_path, const std::string &request_content_type, const std::string &request_body, std::string &response_header, std::string &response_body);

    private:
        /* Declared before ip_/port_, so it is initialized first regardless of the order
         * written in the constructors' initializer lists. */
        boost::shared_ptr<TLSSocketWrapper> socket_wrapper_;
        /* May be an IP address or a host name, plus the port. */
        std::string ip_;
        std::string port_;
};
#endif

cpp文件

#include "sync_tls_client.h"

/*
 * Connects to ip:port and performs the TLS client handshake. The whole sequence is retried
 * up to CONNECT_MAX_TRY_TIMES times, with no delay between attempts.
 *
 * On success real_socket_ is non-null. On failure it is null and HandshakeFinished()
 * returns false. std::exceptions raised while connecting/handshaking are logged to stderr
 * and not rethrown.
 *
 * Only the constructing thread can see this object while the constructor runs, so no
 * locking is needed here.
 *
 * NOTE(review): no set_verify_mode(...) call appears here. Unless the default is changed
 * elsewhere, the server certificate may not actually be verified even when ca_cert is
 * supplied, and no SNI host name is set. Confirm this is intended.
 */
TLSSocketWrapper::TLSSocketWrapper(const std::string &ip, const std::string &port, boost::asio::io_context &global_io_context, const std::string &cert, const std::string &key, const std::string &ca_cert, uint64_t time_stamp) :
    ip_(ip),
    port_(port),
    cert_(cert),
    key_(key),
    ca_cert_(ca_cert),
    time_stamp_(time_stamp) {

    int try_times = 0;

    while ((real_socket_ == nullptr) && (try_times < CONNECT_MAX_TRY_TIMES)) {
        try {
            /* NOTE(review): ssl_context is destroyed at the end of this try block while the
             * stream built from it lives on. This relies on the stream holding its own
             * reference to the underlying SSL_CTX; confirm against the Asio docs. */
            ssl::context ssl_context{ssl::context::sslv23_client};

            /* Load the client certificate chain and private key (PEM, in memory) only when both are provided. */
            if (!(cert.empty() || key.empty())) {
                ssl_context.use_certificate_chain(buffer(cert.data(), cert.size()));
                ssl_context.use_private_key(buffer(key.data(), key.size()), ssl::context::file_format::pem);
            }

            if (!ca_cert.empty()) {
                ssl_context.add_certificate_authority(buffer(ca_cert.data(), ca_cert.size()));
            }
            
            real_socket_.reset(new ssl::stream<tcp::socket>(global_io_context, ssl_context));

            boost::system::error_code ip_invalid;
            ip::address ip_address = ip::address::from_string(ip, ip_invalid);
            if (ip_invalid) {
                /* Not a literal IP address: resolve it as a host name and try each result. */
                tcp::resolver resolver{global_io_context};
                auto const results = resolver.resolve(ip, port);
                connect(real_socket_->next_layer(), results.begin(), results.end());
            } else {
                /* NOTE(review): std::stoi throws on a non-numeric port; that exception is
                 * caught below and the attempt is retried, which cannot succeed. */
                tcp::endpoint endpoint(ip_address, std::stoi(port));
                real_socket_->next_layer().connect(endpoint);
            }

            /* Enable TCP keep-alive on the underlying socket. */
            real_socket_->next_layer().set_option(socket_base::keep_alive(true));
            real_socket_->handshake(ssl::stream_base::client);
        } catch(std::exception const& e) {
            /* Any failure (resolve, connect, handshake) is logged and the socket dropped,
             * so the loop condition triggers another attempt. */
            std::cerr << "Handshake Error: " << e.what() << std::endl;
            real_socket_ = nullptr;
        }
        ++try_times;
    }
}

std::string TLSSocketWrapper::GetCert() const {
    return cert_;
}

std::string TLSSocketWrapper::GetKey() const {
    return key_;
}

uint64_t TLSSocketWrapper::GetTimeStamp() const {
    return time_stamp_;
}

bool TLSSocketWrapper::HandshakeFinished() const {
    return real_socket_ != nullptr;
}

std::string TLSSocketWrapper::GetCaCert() const {
    return ca_cert_;
}

/* 返回的错误码是200表示get成功了,其它代表失败 */
int TLSSocketWrapper::GetMessageCore(const std::string &server_path, std::string &header, std::string &body) {
    int error_code;
    try {
        http::request<http::string_body> req(http::verb::get, server_path, HTTP_11_VERSION);
        req.set(http::field::host, ip_);
        req.set(http::field::user_agent, BOOST_BEAST_VERSION_STRING);
        req.set(http::field::keep_alive, DEFAULT_HTTP_ALIVE_CONFIG);

        boost::beast::flat_buffer buffer;
        http::response<http::dynamic_body> res;

        /* 保持锁的时间最短 */
        {
            write_lock lock(rwlock_);
            http::write(*real_socket_, req);
            http::read(*real_socket_, buffer, res);
        }
       

        std::string response_body{ buffers_begin(res.body().data()),buffers_end(res.body().data()) };

        body = response_body;
        //http错误码是返回值
        error_code = res.result_int();
    } catch(std::exception const& e) {
        std::cerr << "Get Message Error: " << e.what() << std::endl;
        error_code = 500;
    }
    return error_code;
}


/* 返回的错误码是200表示get成功了,其它代表失败 */
int TLSSocketWrapper::PostMessageCore(const std::string &server_path, const std::string &request_content_type, const std::string &request_body, std::string &header, std::string &body) {
    int error_code;
    try {

        http::request<http::string_body> req(http::verb::post, server_path, HTTP_11_VERSION);
        req.set(http::field::host, ip_);
        req.set(http::field::user_agent, BOOST_BEAST_VERSION_STRING);
        req.set(http::field::keep_alive, DEFAULT_HTTP_ALIVE_CONFIG);
        req.set(http::field::content_type, request_content_type);
        req.body() = request_body;
        /* 设定完了使用prepare_payload来更新header里面的content_length */
        req.prepare_payload();

        boost::beast::flat_buffer buffer;
        http::response<http::dynamic_body> res;

        /* 保持锁的时间最短 */
        {
            write_lock lock(rwlock_);
            http::write(*real_socket_, req);
            http::read(*real_socket_, buffer, res);
        }

        std::string response_body{ buffers_begin(res.body().data()),buffers_end(res.body().data()) };

        body = response_body;
        //http错误码是返回值
        error_code = res.result_int();

    } catch(std::exception const& e) {
        std::cerr << "Post Message Error: " << e.what() << std::endl;
        error_code = 500;
    }
    return error_code;
}

void SyncTlsConnectionManager::Init() {
    io_service_ = boost::shared_ptr<io_service>(new io_service());

    /* 添加work防止io_service 在pending时退出*/
    io_service::work work{*io_service_};
    this->loop_ = boost::shared_ptr<boost::thread> (new boost::thread([&]{ io_service_->run(); }));

    /* lambda需要捕获this才能调用this的函数,如果是静态或者全局函数那么不需要执行捕获 */
    timer_ = boost::shared_ptr<RepeatedTimer> (new RepeatedTimer(*io_service_, [this](const boost::system::error_code &e) { this->HttpHeartBeat();}));

    /* 15秒钟运行一次心跳保活线程 */
    timer_->Start(/*ms=*/1000*15);
}

void SyncTlsConnectionManager::HttpHeartBeat() {
    std::vector<std::pair<std::string, boost::shared_ptr<TLSSocketWrapper> > > list;
    global_sockets_.GetAllKey(list);
    for (auto iter : list) {
        size_t pos = iter.first.find('#');
        if (pos == iter.first.npos) {
            /* 理论上不可能发生这问题 */
            continue;
        }
        std::string ip(iter.first.substr(0, pos));
        std::string port(iter.first.substr(pos+1));

        /*
        SyncTlsClient heartbeat(ip, port, iter.second);
        std::string header;
        std::string body;
        int error_code = heartbeat.HeartBeat("/monitor/alive", header, body);
        if (error_code != 200) {
        }
        */

    }
}

/* Both accessors return references to members, so every caller shares the single
 * manager-owned io_context and connection cache. */
io_context& SyncTlsConnectionManager::GetGlobalIOContext() {
    io_context &context = this->global_io_context_;
    return context;
}

ConcurrentLRU<std::string, boost::shared_ptr<TLSSocketWrapper>, TlsWrapperCompare>& SyncTlsConnectionManager::GetGlobalSocketWrapper() {
    return this->global_sockets_;
}




/* Constructor for the heartbeat/keep-alive path: reuses an existing connection instead of
 * looking one up in the cache or performing a new handshake. The initializer list follows
 * member declaration order. */
SyncTlsClient::SyncTlsClient(const std::string &ip, const std::string &port, const boost::shared_ptr<TLSSocketWrapper> &connection) :
    socket_wrapper_(connection),
    ip_(ip),
    port_(port) {
}

/*
 * time_stamp is the certificate timestamp (i.e. the timestamp of the certificate files).
 *
 * A cached connection for "ip#port" is reused only if its timestamp is at least
 * time_stamp. In every other case a new TLS handshake is performed with the given
 * cert/key/ca_cert. The new connection is published to the global cache only when its
 * handshake succeeds, but socket_wrapper_ points at it either way; callers check
 * HandshakeFinished().
 */
SyncTlsClient::SyncTlsClient(const std::string &ip, const std::string &port, const std::string &cert, const std::string &key, const std::string &ca_cert, const uint64_t &time_stamp) :
    ip_(ip),
    port_(port) {
    const std::string ref = ip + "#" + port;

    const bool cached = SyncTlsConnectionManager::instance().GetGlobalSocketWrapper().Get(ref, socket_wrapper_);
    if (cached && time_stamp <= socket_wrapper_->GetTimeStamp()) {
        /* The only case that needs no new handshake: the cached connection is current. */
        return;
    }

    boost::shared_ptr<TLSSocketWrapper> fresh(new TLSSocketWrapper(ip, port, SyncTlsConnectionManager::instance().GetGlobalIOContext(), cert, key, ca_cert, time_stamp));
    const bool usable = (fresh != nullptr) && fresh->HandshakeFinished();
    if (usable) {
        /* The global cache is updated only after a successful handshake. */
        SyncTlsConnectionManager::instance().GetGlobalSocketWrapper().Put(ref, fresh);
    }

    socket_wrapper_ = fresh;
}

/* http心跳报文实际上是个get请求 */
int SyncTlsClient::HeartBeat(const std::string &server_path, std::string &response_header, std::string &response_body) {
    int error_code;
    /* real_socket_理论来说不能为空,为空直接说明一开始尝试连接失败,直接返回错误码500即可 */
    if (socket_wrapper_ != nullptr && socket_wrapper_->HandshakeFinished()) {
        /* asio的is_open是个无效函数,因此只能通过尝试发送数据来探测是否该socket还活着/出现异常 */
        error_code = socket_wrapper_->GetMessageCore(server_path, response_header, response_body);
        /* 500代表连接失效了,需要重新连接,并发起请求 */
        if (error_code == 500) {
            /* 因为永远保证socket_wrapper的证书/私钥/时间戳都是最新版,所以使用socket_wrapper_的证书私钥建立连接 */
            boost::shared_ptr<TLSSocketWrapper> new_socket_wrapper = boost::shared_ptr<TLSSocketWrapper> (new TLSSocketWrapper(ip_, port_, SyncTlsConnectionManager::instance().GetGlobalIOContext(), socket_wrapper_->GetCert(), socket_wrapper_->GetKey(), socket_wrapper_->GetCaCert(), socket_wrapper_->GetTimeStamp()));
            if (new_socket_wrapper != nullptr && new_socket_wrapper->HandshakeFinished()) {
                error_code = new_socket_wrapper->GetMessageCore(server_path, response_header, response_body);
                if (error_code == 200) {
                    std::string ref = ip_+"#"+port_;
                    SyncTlsConnectionManager::instance().GetGlobalSocketWrapper().Put(ref, new_socket_wrapper);
                    socket_wrapper_ = new_socket_wrapper;
                }
            }
        }
    } else {
        error_code = 500;
    }
    
    return error_code;
}

int SyncTlsClient::GetMessage(const std::string &server_path, std::string &response_header, std::string &response_body) {
    int error_code;
    /* real_socket_理论来说不能为空,为空直接说明一开始尝试连接失败,直接返回错误码500即可 */
    if (socket_wrapper_ != nullptr && socket_wrapper_->HandshakeFinished()) {
        /* asio的is_open是个无效函数,因此只能通过尝试发送数据来探测是否该socket还活着/出现异常 */
        error_code = socket_wrapper_->GetMessageCore(server_path, response_header, response_body);
        /* 500代表连接失效了,需要重新连接,并发起请求 */
        if (error_code == 500) {
            /* 因为永远保证socket_wrapper的证书/私钥/时间戳都是最新版,所以使用socket_wrapper_的证书私钥建立连接 */
            boost::shared_ptr<TLSSocketWrapper> new_socket_wrapper = boost::shared_ptr<TLSSocketWrapper> (new TLSSocketWrapper(ip_, port_, SyncTlsConnectionManager::instance().GetGlobalIOContext(), socket_wrapper_->GetCert(), socket_wrapper_->GetKey(), socket_wrapper_->GetCaCert(), socket_wrapper_->GetTimeStamp()));
            if (new_socket_wrapper != nullptr && new_socket_wrapper->HandshakeFinished()) {
                error_code = new_socket_wrapper->GetMessageCore(server_path, response_header, response_body);
                if (error_code == 200) {
                    std::string ref = ip_+"#"+port_;
                    SyncTlsConnectionManager::instance().GetGlobalSocketWrapper().Put(ref, new_socket_wrapper);
                    socket_wrapper_ = new_socket_wrapper;
                }
            }
        }
    } else {
        error_code = 500;
    }
    return error_code;
}

int SyncTlsClient::PostMessage(const std::string &server_path, const std::string &request_content_type, const std::string &request_body, std::string &response_header, std::string &response_body) {
    int error_code;
    /* socket_理论来说不能为空,为空直接说明地址无效了,直接返回错误码500即可 */
    if (socket_wrapper_ != nullptr && socket_wrapper_->HandshakeFinished()) {
        /* asio的is_open是个无效函数,因此只能通过尝试发送数据来探测是否该socket还活着 */
        error_code = socket_wrapper_->PostMessageCore(server_path, request_content_type, request_body, response_header, response_body);
        /* 500代表连接失效了,需要重新连接,并发起请求 */
        if (error_code == 500) {
            /* 因为永远保证socket_wrapper的证书/私钥/时间戳都是最新版,所以使用socket_wrapper_的证书私钥建立连接 */
            boost::shared_ptr<TLSSocketWrapper> new_socket_wrapper = boost::shared_ptr<TLSSocketWrapper> (new TLSSocketWrapper(ip_, port_, SyncTlsConnectionManager::instance().GetGlobalIOContext(), socket_wrapper_->GetCert(), socket_wrapper_->GetKey(), socket_wrapper_->GetCaCert(), socket_wrapper_->GetTimeStamp()));
            if (new_socket_wrapper != nullptr && new_socket_wrapper->HandshakeFinished()) {
                error_code = new_socket_wrapper->PostMessageCore(server_path, request_content_type, request_body, response_header, response_body);
                if (error_code == 200) {
                    std::string ref = ip_+"#"+port_;
                    SyncTlsConnectionManager::instance().GetGlobalSocketWrapper().Put(ref, new_socket_wrapper);
                    socket_wrapper_ = new_socket_wrapper;
                }
            }
        }
    } else {
        error_code = 500;
    }
    
    return error_code;
}

5 OPENSSL签名验证签名代码

本来是只用crypto++库做的,然后惊讶地发现crypto++的签名/验签性能低得令人发指,所以最后又改用了openssl的实现。

头文件

#ifndef __SIGN_ALGORITHM_H__
#define __SIGN_ALGORITHM_H__

/* inf cbom的crypto后面没有pp,mac装的需要加个pp */
#ifndef MAC_BUILD
#include <crypto/eccrypto.h>
#include <crypto/osrng.h>
#include <crypto/rsa.h>
#include <crypto/pssr.h>
#include <crypto/base64.h>
#else
#include <cryptopp/eccrypto.h>
#include <cryptopp/osrng.h>
#include <cryptopp/rsa.h>
#include <cryptopp/pssr.h>
#include <cryptopp/base64.h>
#endif
#include <string>
#include <iostream>

#include "sts_enum.h"

#include "openssl/ecdsa.h"
#include "openssl/rsa.h"
#include "openssl/ssl.h"
#include "openssl/ecdh.h"
#include "openssl/evp.h"
#include "openssl/sha.h"
#include "openssl/bio.h"
#include "openssl/pem.h"
#include "openssl/err.h"
#include "openssl/hmac.h"

#define EVP_MAX_SIGNATURE_SIZE 1024

using std::string;

class SignAlgorithm{

    public:

        static bool sign_with_pem(const int &algorithm, const std::string &message, const std::string &key, std::string &signature);
        static bool verify_with_pem(const int &algorithm, const std::string &message, const std::string &key, std::string &signature);

        static bool base64_url_encode(const std::string &in, std::string &out);
        static bool base64_url_decode(const std::string &in, std::string &out);

        static bool base64_encode(const std::string &in, std::string &out);
        static bool base64_decode(const std::string &in, std::string &out);

        template <typename T>
        static bool load_base64_raw_key(const std::string &pem, T &key) {
            ByteQueue queue;
            Base64Decoder decoder;

            decoder.Attach(new Redirector(queue));
            decoder.Put((const byte*)pem.data(), pem.length());
            decoder.MessageEnd();

            try {
                key.BERDecode(queue);
        
#ifdef AUTH_SDK_VALIDATE_KEY
                AutoSeededRandomPool prng;
                bool valid = key.Validate(prng, 3);
                if (!valid) {
                    return false;
                }
#endif
            } catch (const Exception& ex) {
                return false;
            }
            return true;
        }

    private:

        /* 公私密钥均为pem格式*/
        static bool ecdsa_sha256_sign(const std::string &message, const std::string &pem_private_key, std::string &signature);
        static bool ecdsa_sha256_verify(const std::string &message, const std::string &pem_pub_key, std::string &signature);
        static bool ecdsa_sha256_sign_openssl(const std::string &message, const std::string &pem_private_key, std::string &signature);
        static bool ecdsa_sha256_verify_openssl(const std::string &message, const std::string &pem_pub_key, std::string &signature);

        static bool rsa_sha256_sign_openssl(const std::string &message, const std::string &pem_private_key, std::string &signature);
        static bool rsa_sha256_verify_openssl(const std::string &message, const std::string &pem_pub_key, std::string &signature);
        static bool rsa_sha256_sign(const std::string &message, const std::string &pem_private_key, std::string &signature) ;
        static bool rsa_sha256_verify(const std::string &message, const std::string &pem_pub_key, std::string &signature);

        static bool hmac_sha256_sign(const std::string &message, const std::string &key, std::string &signature);
        static bool hmac_sha256_verify(const std::string &message, const std::string &key, std::string &signature);
        static bool hmac_sha256_sign_openssl(const std::string &message, const std::string &key, std::string &signature);
        static bool hmac_sha256_verify_openssl(const std::string &message, const std::string &key, std::string &signature);

};



#endif

cpp文件

#include "sign_algorithm.h"

bool SignAlgorithm::sign_with_pem(const int &algorithm, const std::string &message, const std::string &key, std::string &signature) {
    /* Every backend signer shares this shape: (message, key, out signature) -> success. */
    using SignFn = bool (*)(const std::string &, const std::string &, std::string &);

    /* Select the implementation for the requested algorithm; which backend family is
     * used is fixed at build time by USE_CRYPTOPP_ALG. */
    SignFn signer = nullptr;
    switch (algorithm) {
#ifdef USE_CRYPTOPP_ALG
        case STS_ALG_TYPE_RS256: signer = &SignAlgorithm::rsa_sha256_sign; break;
        case STS_ALG_TYPE_ES256: signer = &SignAlgorithm::ecdsa_sha256_sign; break;
        case STS_ALG_TYPE_HS256: signer = &SignAlgorithm::hmac_sha256_sign; break;
#else
        case STS_ALG_TYPE_RS256: signer = &SignAlgorithm::rsa_sha256_sign_openssl; break;
        case STS_ALG_TYPE_ES256: signer = &SignAlgorithm::ecdsa_sha256_sign_openssl; break;
        case STS_ALG_TYPE_HS256: signer = &SignAlgorithm::hmac_sha256_sign_openssl; break;
#endif
        default: break;
    }

    /* Unknown algorithm identifiers are rejected. */
    return signer != nullptr && signer(message, key, signature);
}


bool SignAlgorithm::verify_with_pem(const int &algorithm, const std::string &message, const std::string &key, std::string &signature) {
    /* Every backend verifier shares this shape: (message, key, signature) -> valid. */
    using VerifyFn = bool (*)(const std::string &, const std::string &, std::string &);

    /* Select the implementation for the requested algorithm; which backend family is
     * used is fixed at build time by USE_CRYPTOPP_ALG. */
    VerifyFn verifier = nullptr;
    switch (algorithm) {
#ifdef USE_CRYPTOPP_ALG
        case STS_ALG_TYPE_RS256: verifier = &SignAlgorithm::rsa_sha256_verify; break;
        case STS_ALG_TYPE_ES256: verifier = &SignAlgorithm::ecdsa_sha256_verify; break;
        case STS_ALG_TYPE_HS256: verifier = &SignAlgorithm::hmac_sha256_verify; break;
#else
        case STS_ALG_TYPE_RS256: verifier = &SignAlgorithm::rsa_sha256_verify_openssl; break;
        case STS_ALG_TYPE_ES256: verifier = &SignAlgorithm::ecdsa_sha256_verify_openssl; break;
        case STS_ALG_TYPE_HS256: verifier = &SignAlgorithm::hmac_sha256_verify_openssl; break;
#endif
        default: break;
    }

    /* Unknown algorithm identifiers never verify. */
    return verifier != nullptr && verifier(message, key, signature);
}



bool SignAlgorithm::ecdsa_sha256_sign(const std::string &message, const std::string &pem_private_key, std::string &signature) {

    CryptoPP::ECDSA<CryptoPP::ECP, CryptoPP::SHA256>::PrivateKey private_key;
    if (!load_base64_raw_key(pem_private_key, private_key)) {
        return false;
    }

    CryptoPP::AutoSeededRandomPool prng;
    CryptoPP::ECDSA<CryptoPP::ECP, CryptoPP::SHA256>::Signer signer(private_key);

    try {
        CryptoPP::StringSource s( message, true /* pump all */,
            new CryptoPP::SignerFilter(prng, 
                signer,
                new CryptoPP::StringSink( signature )
            )
        ); 
    } catch (CryptoPP::Exception e) {
        return false;
    }
    return true;
}

bool SignAlgorithm::ecdsa_sha256_verify(const std::string &message, const std::string &pem_pub_key, std::string &signature) {

    CryptoPP::ECDSA<CryptoPP::ECP, CryptoPP::SHA256>::PublicKey public_key;
    if (!load_base64_raw_key(pem_pub_key, public_key)) {
        return false;
    }

    bool ret = false;
    CryptoPP::ECDSA<CryptoPP::ECP, CryptoPP::SHA256>::Verifier verifier(public_key);

    try {
        CryptoPP::StringSource s( signature+message, true /* pump all */,
            /* 默认就是把校验结果放到ret里 */
            new CryptoPP::SignatureVerificationFilter(
                verifier, 
                new CryptoPP::ArraySink( (byte*)&ret, sizeof(ret) )
            )
        );
    } catch (CryptoPP::Exception e) {
        ret = false;
    }

    return ret;
}

bool SignAlgorithm::rsa_sha256_sign(const std::string &message, const std::string &pem_private_key, std::string &signature) {

    CryptoPP::RSA::PrivateKey private_key;
    if (!load_base64_raw_key(pem_private_key, private_key)) {
        return false;
    }

    CryptoPP::RSASS<CryptoPP::PSS, CryptoPP::SHA256>::Signer signer(private_key);
    CryptoPP::AutoSeededRandomPool prng;
    
    bool ret = true;
    try {
        CryptoPP::StringSource s(message, true, 
            new CryptoPP::SignerFilter(prng, signer,
                new CryptoPP::StringSink( signature )
            )
        ); 
    } catch (CryptoPP::Exception e) {
        ret = false;
    }
    return ret;
}

bool SignAlgorithm::rsa_sha256_verify(const std::string &message, const std::string &pem_pub_key, std::string &signature) {

    CryptoPP::RSA::PublicKey public_key;
    if (!load_base64_raw_key(pem_pub_key, public_key)) {
        return false;
    }

    CryptoPP::RSASS<CryptoPP::PSS, CryptoPP::SHA256>::Verifier verifier(public_key);

    bool ret = true;
    try {
        CryptoPP::StringSource s(signature+message, true,
            new CryptoPP::SignatureVerificationFilter(
                verifier, 
                new CryptoPP::ArraySink( (byte*)&ret, sizeof(ret) )
            )
        );
    } catch (CryptoPP::Exception e) {
        ret = false;
    }
    return ret;
}

bool SignAlgorithm::hmac_sha256_sign(const std::string &message, const std::string &key, std::string &signature) {
    bool ret = true;
    try {
        /* 输入的密钥是base64编码的,所以需要先解码 */
        std::string base64_key;
        SignAlgorithm::base64_decode(key, base64_key);
        CryptoPP::HMAC<CryptoPP::SHA256> hmac((byte*) base64_key.c_str(), base64_key.length());

        CryptoPP::StringSource(message, true, new CryptoPP::HashFilter(hmac, new CryptoPP::StringSink(signature)));
    } catch (CryptoPP::Exception e) {
        ret = false; 
    }
    return ret;
}

/* Verify an HMAC-SHA256 `signature` over `message` (Crypto++ backend). `key` is the
 * base64-encoded shared secret. The expected MAC is recomputed with hmac_sha256_sign and
 * compared in constant time, so the check does not leak how many leading bytes matched. */
bool SignAlgorithm::hmac_sha256_verify(const std::string &message, const std::string &key, std::string &signature) {
    std::string computed_signature;
    if (!SignAlgorithm::hmac_sha256_sign(message, key, computed_signature)) {
        return false;
    }

    /* The MAC length is public (32 bytes for SHA-256); only the contents need constant time. */
    if (computed_signature.size() != signature.size()) {
        return false;
    }
    unsigned char diff = 0;
    for (std::string::size_type i = 0; i < computed_signature.size(); ++i) {
        diff |= static_cast<unsigned char>(computed_signature[i] ^ signature[i]);
    }
    return diff == 0;
}

bool SignAlgorithm::base64_url_encode(const std::string &in, std::string &out) {
    bool ret = true;
    out.clear();
    try {
        CryptoPP::StringSource(in, true, new CryptoPP::Base64URLEncoder(new CryptoPP::StringSink(out)));
    } catch (CryptoPP::Exception e) {
        ret = false; 
    }
    return ret;
}

bool SignAlgorithm::base64_url_decode(const std::string &in, std::string &out) {
    bool ret = true;
    out.clear();
    try {
        CryptoPP::StringSource(in, true, new CryptoPP::Base64URLDecoder(new CryptoPP::StringSink(out)));
    } catch (CryptoPP::Exception e) {
        ret = false; 
    }
    return ret;
}


bool SignAlgorithm::base64_encode(const std::string &in, std::string &out) {
    bool ret = true;
    out.clear();
    try {
        CryptoPP::StringSource(in, true, new CryptoPP::Base64Encoder(new CryptoPP::StringSink(out)));
    } catch (CryptoPP::Exception e) {
        ret = false; 
    }
    return ret;
}

bool SignAlgorithm::base64_decode(const std::string &in, std::string &out) {
    bool ret = true;
    out.clear();
    try {
        CryptoPP::StringSource(in, true, new CryptoPP::Base64Decoder(new CryptoPP::StringSink(out)));
    } catch (CryptoPP::Exception e) {
        ret = false; 
    }
    return ret;
}

/* 签名的最后结果并不一定小于64,对ECDSA而言256bit可以达到72字节。所以signature_buf长度最好保存1024,直接防备rsa 4096/1024*8的密钥 */
bool SignAlgorithm::ecdsa_sha256_sign_openssl(const std::string &message, const std::string &pem_private_key, std::string &signature) {
    uint32_t digest_len = 0;
    uint32_t signature_len = 0;
    unsigned char digest[EVP_MAX_MD_SIZE];
    char signature_buf[EVP_MAX_SIGNATURE_SIZE];

    std::string key = "-----BEGIN PRIVATE KEY-----\n" + pem_private_key + "\n-----END PRIVATE KEY-----";
    BIO *bio = BIO_new(BIO_s_mem());
    if (bio == nullptr) {
        return false;
    }
    BIO_puts(bio, key.c_str());

    EC_KEY *ec_key = PEM_read_bio_ECPrivateKey(bio, NULL, NULL, NULL);;
    if (ec_key == nullptr) {
        BIO_free(bio);
        return false;
    }

    EVP_MD_CTX *md_ctx = EVP_MD_CTX_create();
    if (md_ctx == nullptr) {
        BIO_free(bio);
        EC_KEY_free(ec_key);
        return false;
    }

    bool ret = true;
    bzero(digest, EVP_MAX_MD_SIZE);
    if (EVP_DigestInit(md_ctx, EVP_sha256()) <= 0
        ||EVP_DigestUpdate(md_ctx, (const void *)message.c_str(), message.length()) <= 0
            ||EVP_DigestFinal(md_ctx, digest, &digest_len) <= 0
                ||ECDSA_sign(0, digest, digest_len, (unsigned char*)signature_buf, &signature_len, ec_key) <= 0) {
        ret = false;
    }


    if (ret == true) {
        std::string final_sig(signature_buf, signature_len);
        signature = final_sig;
    }


    BIO_free(bio);
    EC_KEY_free(ec_key);
    EVP_MD_CTX_destroy(md_ctx);

    return ret;

}
/* Verify a DER-encoded ECDSA/SHA-256 `signature` over `message` with OpenSSL.
 * `pem_pub_key` is the base64 body of a PEM (SubjectPublicKeyInfo) public key.
 * Returns true only when ECDSA_verify reports a valid signature. */
bool SignAlgorithm::ecdsa_sha256_verify_openssl(const std::string &message, const std::string &pem_pub_key, std::string &signature) {
    /* unsigned int, not uint32_t: that is the exact pointee type EVP_DigestFinal takes. */
    unsigned int digest_len = 0;
    unsigned char digest[EVP_MAX_MD_SIZE];

    /* The key is supplied without armour; re-add the PUBLIC KEY header/footer. */
    std::string key = "-----BEGIN PUBLIC KEY-----\n" + pem_pub_key + "\n-----END PUBLIC KEY-----";
    BIO *bio = BIO_new(BIO_s_mem());
    if (bio == nullptr) {
        return false;
    }
    if (BIO_puts(bio, key.c_str()) <= 0) {
        BIO_free(bio);
        return false;
    }

    EC_KEY *ec_key = PEM_read_bio_EC_PUBKEY(bio, NULL, NULL, NULL);
    /* The BIO is only needed while parsing the key. */
    BIO_free(bio);
    if (ec_key == nullptr) {
        return false;
    }

    EVP_MD_CTX *md_ctx = EVP_MD_CTX_create();
    if (md_ctx == nullptr) {
        EC_KEY_free(ec_key);
        return false;
    }

    /* ECDSA_verify returns 1 (valid), 0 (invalid) or -1 (error); only 1 is success. */
    const bool ret = EVP_DigestInit(md_ctx, EVP_sha256()) > 0
        && EVP_DigestUpdate(md_ctx, message.data(), message.size()) > 0
        && EVP_DigestFinal(md_ctx, digest, &digest_len) > 0
        && ECDSA_verify(0, digest, static_cast<int>(digest_len),
                        reinterpret_cast<const unsigned char *>(signature.data()),
                        static_cast<int>(signature.size()), ec_key) == 1;

    EC_KEY_free(ec_key);
    EVP_MD_CTX_destroy(md_ctx);
    return ret;
}

bool SignAlgorithm::rsa_sha256_sign_openssl(const std::string &message, const std::string &pem_private_key, std::string &signature) {
    uint32_t digest_len = 0;
    uint32_t signature_len = 0;
    unsigned char digest[EVP_MAX_MD_SIZE];
    char signature_buf[EVP_MAX_SIGNATURE_SIZE];

    std::string key = "-----BEGIN PRIVATE KEY-----\n" + pem_private_key + "\n-----END PRIVATE KEY-----";
    BIO *bio = BIO_new(BIO_s_mem());
    if (bio == nullptr) {
        return false;
    }
    BIO_puts(bio, key.c_str());

    RSA *rsa_key = PEM_read_bio_RSAPrivateKey(bio, NULL, NULL, NULL);;
    if (rsa_key == nullptr) {
        BIO_free(bio);
        return false;
    }

    EVP_MD_CTX *md_ctx = EVP_MD_CTX_create();
    if (md_ctx == nullptr) {
        BIO_free(bio);
        RSA_free(rsa_key);
        return false;
    }
    bool ret = true;

    if (EVP_DigestInit(md_ctx, EVP_sha256()) <= 0
        ||EVP_DigestUpdate(md_ctx, (const void *)message.c_str(), message.length()) <= 0
            ||EVP_DigestFinal(md_ctx, digest, &digest_len) <= 0
                ||RSA_sign(NID_sha256, digest, digest_len, (unsigned char*)signature_buf, &signature_len, rsa_key) <= 0) {
        ret = false;
    }


    if (ret == true) {
        std::string final_sig(signature_buf, signature_len);

        signature = final_sig;
    }
    BIO_free(bio);
    RSA_free(rsa_key);
    EVP_MD_CTX_destroy(md_ctx);
    return ret;
}

/* Verify an RS256 (RSASSA-PKCS1-v1_5 / SHA-256) `signature` over `message` with OpenSSL.
 * `pem_pub_key` is the base64 body of a PEM (SubjectPublicKeyInfo) public key. */
bool SignAlgorithm::rsa_sha256_verify_openssl(const std::string &message, const std::string &pem_pub_key, std::string &signature) {
    /* unsigned int, not uint32_t: that is the exact pointee type EVP_DigestFinal takes. */
    unsigned int digest_len = 0;
    unsigned char digest[EVP_MAX_MD_SIZE];

    /* The key is supplied without armour; re-add the PUBLIC KEY header/footer. */
    std::string key = "-----BEGIN PUBLIC KEY-----\n" + pem_pub_key + "\n-----END PUBLIC KEY-----";
    BIO *bio = BIO_new(BIO_s_mem());
    if (bio == nullptr) {
        return false;
    }
    if (BIO_puts(bio, key.c_str()) <= 0) {
        BIO_free(bio);
        return false;
    }

    RSA *rsa_key = PEM_read_bio_RSA_PUBKEY(bio, NULL, NULL, NULL);
    /* The BIO is only needed while parsing the key. */
    BIO_free(bio);
    if (rsa_key == nullptr) {
        return false;
    }

    EVP_MD_CTX *md_ctx = EVP_MD_CTX_create();
    if (md_ctx == nullptr) {
        RSA_free(rsa_key);
        return false;
    }

    /* RSA_verify returns 1 when the signature is valid. */
    const bool ret = EVP_DigestInit(md_ctx, EVP_sha256()) > 0
        && EVP_DigestUpdate(md_ctx, message.data(), message.size()) > 0
        && EVP_DigestFinal(md_ctx, digest, &digest_len) > 0
        && RSA_verify(NID_sha256, digest, digest_len,
                      reinterpret_cast<const unsigned char *>(signature.data()),
                      static_cast<unsigned int>(signature.size()), rsa_key) > 0;

    RSA_free(rsa_key);
    EVP_MD_CTX_destroy(md_ctx);
    return ret;
}


/* HMAC-SHA256 of `message` with OpenSSL. `key` is the base64-encoded shared secret.
 * On success `signature` is replaced with the raw MAC; on failure it is left untouched. */
bool SignAlgorithm::hmac_sha256_sign_openssl(const std::string &message, const std::string &key, std::string &signature) {
    /* The secret arrives base64-encoded; a decode failure must not silently produce a MAC. */
    std::string real_key;
    if (!base64_decode(key, real_key)) {
        return false;
    }

    HMAC_CTX *hmac = HMAC_CTX_new();
    if (hmac == nullptr) {
        return false;
    }

    unsigned char signature_buf[EVP_MAX_MD_SIZE];
    unsigned int signature_len = 0;
    /* HMAC_Init_ex takes the key length as int; each step must succeed before the next runs. */
    const bool ret = HMAC_Init_ex(hmac, real_key.data(), static_cast<int>(real_key.size()), EVP_sha256(), NULL) > 0
        && HMAC_Update(hmac, reinterpret_cast<const unsigned char *>(message.data()), message.size()) > 0
        && HMAC_Final(hmac, signature_buf, &signature_len) > 0;

    if (ret) {
        signature.assign(reinterpret_cast<const char *>(signature_buf), signature_len);
    }

    HMAC_CTX_free(hmac);
    return ret;
}
bool SignAlgorithm::hmac_sha256_verify_openssl(const std::string &message, const std::string &key, std::string &signature) {
    std::string compute_signature;
    if (!hmac_sha256_sign_openssl(message, key, compute_signature)) {
        return false;
    }
    if (compute_signature.compare(signature) == 0) {
        return true;
    }
    return false;
}

6 log库

6.1 boost log使用版本

使用boost的log实现的方法,简单来说基本不需要做什么操作,就是初始化完成就可以调用记录日志的函数了。需要注意的是这个是同步的sink

/* .h文件 */
#ifndef __BOOST_BASE_LOG__
#define __BOOST_BASE_LOG__

#include <iostream>

#include <boost/log/core.hpp>
#include <boost/log/trivial.hpp>
#include <boost/log/expressions.hpp>
#include <boost/log/sinks/text_file_backend.hpp>
#include <boost/log/utility/setup/file.hpp>
#include <boost/log/utility/setup/common_attributes.hpp>
#include <boost/log/sources/severity_logger.hpp>
#include <boost/log/sources/record_ostream.hpp>


void init_boost_base_log();

#endif


/* .cpp文件 */
#include "boost_base_log.h"

void init_boost_base_log() {
    boost::log::add_file_log(
        /* 具体记录日志的名称 */
        boost::log::keywords::file_name = "sample_%N.log",
        /* 文件10MB的时候rotate一次 */
        boost::log::keywords::rotation_size = 10 * 1024 * 1024,
        /* 或者每天午夜的时候rotate一次 */
        boost::log::keywords::time_based_rotation = boost::log::sinks::file::rotation_at_time_point(0, 0, 0),
        /* 消息格式 */
        boost::log::keywords::format = "[%TimeStamp%]: %Message%"
    );
    /* 添加附加属性比方说时间戳,行号,线程号等 */
    boost::log::add_common_attributes();
}

/* Usage: one macro per severity level, from trace (lowest) to fatal (highest).
 * These are statements, so in real code they must appear inside a function body. */
BOOST_LOG_TRIVIAL(trace) << "log message";
BOOST_LOG_TRIVIAL(debug) << "log message";
BOOST_LOG_TRIVIAL(info) << "log message";
BOOST_LOG_TRIVIAL(warning) << "log message";
BOOST_LOG_TRIVIAL(error) << "log message";
BOOST_LOG_TRIVIAL(fatal) << "log message"; 

/* Build and link settings
 * With CMake, add add_definitions(-DBOOST_LOG_DYN_LINK) (needed when linking Boost.Log dynamically),
 * then link with:
 * target_link_libraries(boost_log_test -lboost_log-mt -lboost_log_setup-mt -lboost_thread-mt)
 */

6.2 spdlog

#include <memory>
#include <thread>
#include <iostream>


#include "spdlog/spdlog.h"
#include "spdlog/sinks/rotating_file_sink.h"
#include "spdlog/sinks/daily_file_sink.h"
#include "spdlog/sinks/stdout_color_sinks.h"

/* Logging macros.
 * - Disabled (#if 0) variant: each record goes to spdlog's default logger AND to the
 *   logger registered as "file_log". Each of those macros expands to two statements
 *   joined by ';', so they are unsafe as the body of an unbraced if/else.
 * - Active variant: each record goes only to the "file_log" logger created in init_log().
 * NOTE(review): spdlog::get("file_log") looks the logger up by name on every call and
 * yields a null pointer if init_log() has not run (or after stop_log()); confirm callers
 * only log between the two.
 * NOTE(review): SPDLOG_LOGGER_DEBUG is compiled out unless SPDLOG_ACTIVE_LEVEL is set to
 * DEBUG or lower before spdlog.h is included. */
#if 0
#define LOG_DEBUG(...) SPDLOG_LOGGER_DEBUG(spdlog::default_logger_raw(), __VA_ARGS__);SPDLOG_LOGGER_DEBUG(spdlog::get("file_log"), __VA_ARGS__)
#define LOG_INFO(...) SPDLOG_LOGGER_INFO(spdlog::default_logger_raw(), __VA_ARGS__);SPDLOG_LOGGER_INFO(spdlog::get("file_log"), __VA_ARGS__)
#define LOG_WARN(...) SPDLOG_LOGGER_WARN(spdlog::default_logger_raw(), __VA_ARGS__);SPDLOG_LOGGER_WARN(spdlog::get("file_log"), __VA_ARGS__)
#define LOG_ERROR(...) SPDLOG_LOGGER_ERROR(spdlog::default_logger_raw(), __VA_ARGS__);SPDLOG_LOGGER_ERROR(spdlog::get("file_log"), __VA_ARGS__)
#else
#define LOG_DEBUG(...) SPDLOG_LOGGER_DEBUG(spdlog::get("file_log"), __VA_ARGS__)
#define LOG_INFO(...) SPDLOG_LOGGER_INFO(spdlog::get("file_log"), __VA_ARGS__)
#define LOG_WARN(...) SPDLOG_LOGGER_WARN(spdlog::get("file_log"), __VA_ARGS__)
#define LOG_ERROR(...) SPDLOG_LOGGER_ERROR(spdlog::get("file_log"), __VA_ARGS__)

#endif
std::shared_ptr<spdlog::logger> file_logger;

/* Create the rotating "file_log" logger, start the periodic flusher and set the global
 * output pattern. The <%s>|<%#>|<%!> fields (source file, line, function) are only filled
 * when logging through the SPDLOG_LOGGER_* macros, which the LOG_* macros use. */
void init_log() {
    const std::size_t max_file_size = 1024 * 1024 * 100;  /* rotate after 100 MiB */
    const std::size_t max_files = 3;                      /* keep at most 3 rotated files */
    file_logger = spdlog::rotating_logger_mt("file_log", "log", max_file_size, max_files);

    /* Alternative: a daily-rotating logger instead of a size-rotating one. */
    //auto logger = spdlog::daily_logger_mt("daily_logger", "logs/daily.txt", 2, 30);
    /* Alternative: flush immediately on every record at warn level or above. */
    //file_logger->flush_on(spdlog::level::warn);

    /* Flush all registered loggers every 3 seconds. */
    spdlog::flush_every(std::chrono::seconds(3));
    spdlog::set_pattern("%Y-%m-%d %H:%M:%S [%l] [%t] - <%s>|<%#>|<%!>,%v");
}

void stop_log() {
    spdlog::drop_all();
}

7 varint编解码

在实现内部工程的时候遇到一个问题,就是varint的编解码实现。这里借鉴了网上的开源代码,修复了原代码只适用于unsigned int类型数据的bug,并且添加了一个获取varint编码长度的函数。但需要注意一点:varint解码时,待解码的数据是否合法需要调用方自己判断。检查的方法也很简单,无非就是检查解码varint时是否超过5字节、解码varlong时是否超过10字节,或者已经到了数据末尾,但最后一个字节的最高位仍为1(表示后面还有字节)。代码大致如下:

    /* Validate a 32-bit varint in input[0..length) before decoding it: count bytes up to and
     * including the first one without the continuation bit (0x80). A valid 32-bit varint
     * uses at most 5 bytes. */
    int varint_size = 1;
    for (int i = 0 ; i < length; ++i) {
        if (input[i] & 0x80) {
            varint_size += 1;
            if (varint_size > 5) {
                return false;
            }
        } else {
            break;
        }  
    }

    /* Reaching here means the 5-byte limit was not exceeded. The input may still have run out
     * while the last byte had its high bit set (varint_size == length + 1), which is also invalid. */
    if (varint_size == length + 1) {
        return false;
    }

头文件

#ifndef __VARINT_HPP__
#define __VARINT_HPP__

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <type_traits>

/* Number of bytes the varint encoding of `value` occupies (7 payload bits per byte, minimum 1). */
std::size_t varintSize(std::size_t value);

/**
 * Encodes an integer as an unsigned variable-length integer using the MSB algorithm
 * (7 payload bits per byte, least-significant group first, high bit = "more bytes follow").
 * Signed values are encoded through their unsigned counterpart, i.e. as their two's-complement
 * bit pattern: int32_t(-1) takes 5 bytes and round-trips through decodeVarint<int32_t>.
 * Previously a negative value failed the `> 127` test and was truncated to a single byte.
 * @param value The input value. Any standard integer type is allowed.
 * @param output A pointer to a piece of reserved memory. Must have a minimum size dependent on the input size (32 bit = 5 bytes, 64 bit = 10 bytes).
 * @return The number of bytes used in the output memory.
 */
template<typename int_t = uint64_t>
std::size_t encodeVarint(int_t value, uint8_t* output) {
    using uint_type = typename std::make_unsigned<int_t>::type;
    uint_type remaining = static_cast<uint_type>(value);
    std::size_t outputSize = 0;
    // While more than 7 bits of data are left, emit the low 7 bits with the
    // next-byte flag (128) set, then drop those bits.
    while (remaining > 127) {
        output[outputSize++] = static_cast<uint8_t>((remaining & 127) | 128);
        remaining >>= 7;
    }
    // Final byte: at most 7 bits remain, next-byte flag clear.
    output[outputSize++] = static_cast<uint8_t>(remaining);
    return outputSize;
}


/**
 * Decodes an unsigned variable-length integer using the MSB algorithm.
 * Decoding stops at the first byte whose next-byte flag (128) is clear, after inputSize
 * bytes, or once every bit of int_t has been filled. Bytes of an over-long encoding are
 * ignored instead of being shifted past the type width, which was undefined behaviour.
 * Signed int_t values are decoded through their unsigned counterpart, matching encodeVarint.
 * Whether the input is a well-formed varint must still be checked by the caller.
 * @param input Pointer to the encoded bytes.
 * @param inputSize How many bytes are available at input.
 * @return The decoded value.
 */
template<typename int_t = uint64_t>
int_t decodeVarint(uint8_t* input, std::size_t inputSize) {
    using uint_type = typename std::make_unsigned<int_t>::type;
    const std::size_t bits = sizeof(uint_type) * 8;
    uint_type ret = 0;
    for (std::size_t i = 0, shift = 0; i < inputSize && shift < bits; ++i, shift += 7) {
        ret |= static_cast<uint_type>(static_cast<uint_type>(input[i] & 127) << shift);
        // Stop once the next-byte flag is clear.
        if (!(input[i] & 128)) {
            break;
        }
    }
    return static_cast<int_t>(ret);
}

#endif 

cpp文件

#include "varint.hpp"


std::size_t varintSize(std::size_t value) {
    /* One byte for the lowest 7-bit group, plus one more for every further non-zero group. */
    std::size_t bytes = 1;
    for (std::size_t rest = value >> 7; rest != 0; rest >>= 7) {
        ++bytes;
    }
    return bytes;
}

线程池

先来点基础知识:

C++ FUTURE提供了下面这些类型:

  • Providers 类:std::promise, std::packaged_task
  • Futures 类:std::future, shared_future.
  • Providers 函数:std::async()
  • 其他类型:std::future_error, std::future_errc, std::future_status, std::launch.

std::future && std::packaged_task

什么是std::future? std::future 可以用来获取异步任务的结果,因此可以把它当成一种简单的线程间同步的手段。std::future 通常由某个 Provider 创建,你可以把 Provider 想象成一个异步任务的提供者,Provider 在某个线程中设置共享状态的值,与该共享状态相关联的 std::future 对象调用 get(通常在另外一个线程中) 获取该值,如果共享状态的标志不为 ready,则调用 std::future::get 会阻塞当前的调用者,直到 Provider 设置了共享状态的值(此时共享状态的标志变为 ready),std::future::get 返回异步任务的值或异常(如果发生了异常)。

什么是std::packaged_task? std::packaged_task 包装一个可调用的对象,并且允许异步获取该可调用对象产生的结果,std::packaged_task 与 std::function 类似,只不过 std::packaged_task 将其包装的可调用对象的执行结果通过task.get_future传递给一个 std::future 对象(该对象通常在另外一个线程中获取 std::packaged_task 任务的执行结果)。

有了这些基础,就可以用线程池来异步执行任务了:

#include <memory>
#include <utility>

#include "thread_pool.h"

namespace async {

// Schedule a task to the given thread pool. If thread_pool is null, run the
// task synchronously on the current thread.
//
// Null-pool path: the call is wrapped in a deferred std::async and
// future.Wait() is invoked before returning — presumably Wait() waits on the
// underlying std::future, which is what runs a deferred task; confirm against
// the Future implementation. Both branches return the same
// Future<result-of-Func(Args...)> type, which auto return deduction requires.
// NOTE(review): std::result_of is deprecated in C++17 and removed in C++20.
template <typename Func, typename... Args>
auto ScheduleFuture(ThreadPool* thread_pool, Func&& f, Args&&... args) {
  if (thread_pool == nullptr) {
    Future<typename std::result_of<Func(Args...)>::type> future(
        std::async(std::launch::deferred, std::forward<Func>(f),
                   std::forward<Args>(args)...));
    future.Wait();
    return future;
  }
  return thread_pool->Schedule(std::forward<Func>(f),
                               std::forward<Args>(args)...);
}

// Schedule a task to the default thread pool (ThreadPool::DefaultPool()).
template <typename Func, typename... Args>
auto ScheduleFuture(Func&& f, Args&&... args) {
  return ScheduleFuture(ThreadPool::DefaultPool(), std::forward<Func>(f),
                        std::forward<Args>(args)...);
}

// Asynchronously destroy a container.
// The container is moved into the closure, and the task body moves it again
// into a local. Its destructor therefore runs on the pool thread when the task
// executes, not whenever the queued task object happens to be released. The
// returned Future is discarded.
template <typename ContainerT>
void DestroyContainerAsync(ThreadPool* thread_pool, ContainerT container) {
  ScheduleFuture(thread_pool, [_ = std::move(container)]() mutable {
    const auto unused = std::move(_);
  });
}

// Same, but use the disposal thread pool (ThreadPool::DisposalPool()), not
// the default pool.
template <typename ContainerT>
void DestroyContainerAsync(ContainerT container) {
  DestroyContainerAsync(ThreadPool::DisposalPool(), std::move(container));
}

// Block until `future` is ready (delegates to Future::Wait()).
template <typename T>
void WaitForFuture(const Future<T>& future) {
  future.Wait();
}

}  // namespace async

下面看线程池的代码,

class ThreadPool {
 public:
  // Starts `num_workers` worker threads. Each worker calls init_thread(index)
  // (when non-empty) before processing tasks. A pool with zero workers is an
  // "inline" pool: Schedule() runs every task on the calling thread.
  explicit ThreadPool(int num_workers,
                      const std::function<void(int index)>& init_thread = {});
  // Requests a stop, lets the workers finish the tasks already queued, then
  // joins them.
  ~ThreadPool();

  // Return the default thread pool (sized by FLAGS_default_pool_size).
  static ThreadPool* DefaultPool();

  // Return the disposal thread pool (one worker) to destroy stuff asynchronously.
  static ThreadPool* DisposalPool();

  ThreadPool(const ThreadPool&) = delete;
  ThreadPool& operator=(const ThreadPool&) = delete;

  // Number of worker threads; 0 for an inline pool.
  int NumWorkers() const { return static_cast<int>(workers_.size()); }

  // Future type returned by Schedule() for callable Func invoked with Args.
  // NOTE(review): std::result_of is deprecated in C++17 and removed in C++20;
  // switch to std::invoke_result once the project builds as C++17.
  template <class Func, class... Args>
  using FutureType = Future<typename std::result_of<Func(Args...)>::type>;

  // Schedule a new task: f(args...) runs on a worker (or inline when the pool
  // has no workers), and its result is delivered through the returned Future.
  template <class Func, class... Args>
  FutureType<Func, Args...> Schedule(Func&& f, Args&&... args)
      LOCKS_EXCLUDED(mutex_);

 private:
  absl::Mutex mutex_;
  // Signalled when a task is queued or a stop is requested.
  absl::CondVar cond_var_ GUARDED_BY(mutex_);

  // Populated only by the constructor; read without the lock afterwards.
  std::vector<std::thread> workers_;
  // Pending tasks, guarded by mutex_; idle workers wait on cond_var_.
  std::queue<std::function<void()>> tasks_ GUARDED_BY(mutex_);
  // Set by the destructor; workers exit once it is set and tasks_ is empty.
  bool stop_requested_ GUARDED_BY(mutex_) = false;
};


// Wrap f(args...) in a packaged_task and return a Future bound to its result.
// The packaged_task is move-only, but std::function (the queue element type)
// needs a copyable callable. The task is therefore held in a shared_ptr, and
// the queued lambda copies only that pointer.
// std::bind stores decayed copies of args and passes them to f as lvalues.
// NOTE(review): the stop_requested_ check is commented out, so a task queued
// after every worker has exited would never run and its Future would never
// become ready — confirm callers never schedule during destruction.
template <class Func, class... Args>
ThreadPool::FutureType<Func, Args...> ThreadPool::Schedule(Func&& f,
                                                           Args&&... args) {
  using ReturnType = typename std::result_of<Func(Args...)>::type;
  const auto task = std::make_shared<std::packaged_task<ReturnType()>>(
      std::bind(std::forward<Func>(f), std::forward<Args>(args)...));
  Future<ReturnType> res(task->get_future());

  // If there is no worker, this is an inline thread pool, and the task will be
  // immediately run on the current thread.
  int64_t tasks_size = -1;
  if (workers_.empty()) {
    (*task)();
  } else {
    // Enqueue under the lock and wake one waiting worker.
    absl::MutexLock lock(&mutex_);
    // QCHECK(!stop_requested_) << "The thread pool has been stopped";
    tasks_.emplace([task]() { (*task)(); });
    tasks_size = tasks_.size();
    cond_var_.Signal();
  }
  // Queue-length metric, reported after the lock is released; -1 means the
  // task ran inline and nothing was queued.
  if (tasks_size >= 0) {
    QCOUNTER("schedulefuture_callback_size", tasks_size);
  }
  return res;
}

}


#include "thread_pool.h"

// Start `num_workers` worker threads. Each worker:
//   1. calls init_thread(index) once, when a callback was provided;
//   2. loops: waits on cond_var_ until a task is queued or a stop is requested,
//      pops one task under mutex_, and runs it outside the lock.
// A worker returns only when a stop was requested AND the queue is empty, so
// tasks already queued at shutdown are still executed.
ThreadPool::ThreadPool(int num_workers,
                       const std::function<void(int index)>& init_thread) {
  for (int index = 0; index < num_workers; ++index) {
    // init_thread is captured by copy: the reference parameter does not outlive
    // this constructor, but the worker thread does.
    workers_.emplace_back([this, index, init_thread] {
      if (init_thread) {
        init_thread(index);
      }

      while (true) {
        std::function<void()> task;
        {
          absl::MutexLock lock(&mutex_);

          // Loop to tolerate spurious wakeups.
          while (!stop_requested_ && tasks_.empty()) {
            cond_var_.Wait(&mutex_);
          }
          if (stop_requested_ && tasks_.empty()) {
            return;
          }
          task = std::move(tasks_.front());
          tasks_.pop();
        }
        // Run the task without holding the lock so other workers can proceed.
        task();
      }
    });
  }
}

// Request a stop, wake every waiting worker, and join them all. The workers
// drain any tasks still queued before exiting (see the worker loop in the
// constructor), so the destructor blocks until that work is finished.
ThreadPool::~ThreadPool() {
  {
    absl::MutexLock lock(&mutex_);
    stop_requested_ = true;
    cond_var_.SignalAll();
  }
  for (std::thread& worker : workers_) {
    worker.join();
  }
}

// static
ThreadPool* ThreadPool::DefaultPool() {
  // Created on first use and never deleted, so the pool stays usable even while
  // other statics are being destroyed (presumably deliberate).
  static ThreadPool* const kDefaultPool =
      new ThreadPool(FLAGS_default_pool_size);
  return kDefaultPool;
}

// static
ThreadPool* ThreadPool::DisposalPool() {
  // Single-worker pool used by async::DestroyContainerAsync(); likewise created
  // on first use and never deleted.
  static ThreadPool* const kDisposalPool = new ThreadPool(/*num_workers=*/1);
  return kDisposalPool;
}

结尾

唉,尴尬

狗头的赞赏码.jpg