C++开发工具库
1 并发多线程map
坦白讲核心思想都是一样的:多个bucket,每个bucket下面挂一棵红黑树(实际上就是stl的map),红黑树使用boost的读写锁来保护。源代码里面我禁用了拷贝构造和赋值运算符,因为RAII(锁)不可拷贝是一个通识,如果想拷贝那么就注释掉宏标记的地方,但是尽量建议使用shared_ptr来做这个事情。
#ifndef __CONCURRENT_MAP_H__
#define __CONCURRENT_MAP_H__
#include <atomic>
#include <cassert>
#include <cstdint>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <utility>
#include <vector>
#include <boost/thread/locks.hpp>
#include <boost/thread/shared_mutex.hpp>
/* 使用此代码时请注意,直接用.h文件,不要写.cpp,使用置入式模型*/
using std::map;
using namespace boost;
/* bucket的代码没什么难度,只需要注意一点,更新和插入的区别。 更新是原先必须有这个值 */
/* 因为每次查询更新数据都只会锁住一个元素,不会锁住多个元素,所以不会产生死锁 */
/*
 * ConcurrentBucket: one shard of a ConcurrentMap.
 *
 * A std::map<KEY_T, VALUE_T> guarded by a boost::shared_mutex. Read-only
 * operations take a shared lock; mutating operations take an exclusive lock.
 * Each public method locks only this bucket's own mutex.
 *
 * Compare must be default-constructible. Compare()(candidate, stored) == true
 * means the stored value is replaced by the candidate; otherwise the stored
 * value is kept and the operation reports false.
 */
template<class KEY_T, class VALUE_T, class Compare>
class ConcurrentBucket {
private:
    typedef boost::shared_lock<boost::shared_mutex> read_lock;
    typedef boost::unique_lock<boost::shared_mutex> write_lock;
    typedef std::map<KEY_T, VALUE_T> map_type;
    typedef std::vector<std::pair<KEY_T, VALUE_T> > pair_list;

public:
    ConcurrentBucket() {
    }

    virtual ~ConcurrentBucket() {
        write_lock lock(rwlock_);
        map_.clear();
    }

    /* Copies the value stored under key into value. Returns false (value untouched) when the key is absent. */
    bool Lookup(const KEY_T &key, VALUE_T &value) const {
        read_lock lock(rwlock_);
        typename map_type::const_iterator pos = map_.find(key);
        if (pos == map_.end()) {
            return false;
        }
        value = pos->second;
        return true;
    }

    /* Number of entries currently stored in this bucket. */
    uint64_t Size() const {
        read_lock lock(rwlock_);
        return map_.size();
    }

    /* Removes every entry. */
    void Clear() {
        write_lock lock(rwlock_);
        map_.clear();
    }

    /* True when key is present. */
    bool Contain(const KEY_T &key) const {
        read_lock lock(rwlock_);
        return map_.find(key) != map_.end();
    }

    /*
     * Inserts key/value. If the key already exists the stored value is only
     * replaced when Compare()(value, stored) is true.
     * Returns true when the map was modified.
     */
    bool Insert(const KEY_T &key, const VALUE_T &value) {
        write_lock lock(rwlock_);
        return InsertWithoutLock(key, value);
    }

    /*
     * Replaces the value of an existing key when Compare()(value, stored) is
     * true. Unlike Insert, a missing key is never created.
     * Returns true when the map was modified.
     */
    bool Update(const KEY_T &key, const VALUE_T &value) {
        write_lock lock(rwlock_);
        return UpdateWithoutLock(key, value);
    }

    /* Erases key if present; a missing key is a no-op. */
    void Remove(const KEY_T &key) {
        write_lock lock(rwlock_);
        RemoveWithoutLock(key);
    }

    /* Appends a copy of every (key, value) pair to list; list is not cleared first. */
    void GetAllKey(pair_list &list) const {
        read_lock lock(rwlock_);
        list.reserve(list.size() + map_.size());
        for (typename map_type::const_iterator it = map_.begin(); it != map_.end(); ++it) {
            list.push_back(std::pair<KEY_T, VALUE_T>(it->first, it->second));
        }
    }

    /* Applies Update semantics to every pair in list under a single exclusive lock. */
    void UpdateKeyBatch(const pair_list &list) {
        write_lock lock(rwlock_);
        for (typename pair_list::const_iterator it = list.begin(); it != list.end(); ++it) {
            UpdateWithoutLock(it->first, it->second);
        }
    }

    /* Applies Insert semantics to every pair in list under a single exclusive lock. */
    void InsertKeyBatch(const pair_list &list) {
        write_lock lock(rwlock_);
        for (typename pair_list::const_iterator it = list.begin(); it != list.end(); ++it) {
            InsertWithoutLock(it->first, it->second);
        }
    }

    /* Erases every key in list under a single exclusive lock. */
    void RemoveKeyBatch(const std::vector<KEY_T> &list) {
        write_lock lock(rwlock_);
        for (typename std::vector<KEY_T>::const_iterator it = list.begin(); it != list.end(); ++it) {
            RemoveWithoutLock(*it);
        }
    }

    /* Debug helper: prints a fixed marker to stdout. */
    void Echo() const {
        std::cout << " oh ho" << std::endl;
    }

private:
    /* The *WithoutLock helpers require the caller to already hold the exclusive lock. */
    void RemoveWithoutLock(const KEY_T &key) {
        typename map_type::iterator pos = map_.find(key);
        if (pos != map_.end()) {
            map_.erase(pos);
        }
    }

    bool InsertWithoutLock(const KEY_T &key, const VALUE_T &value) {
        typename map_type::iterator pos = map_.find(key);
        if (pos == map_.end()) {
            map_.insert(std::pair<KEY_T, VALUE_T>(key, value));
            return true;
        }
        return ReplaceIfPreferred(pos, value);
    }

    bool UpdateWithoutLock(const KEY_T &key, const VALUE_T &value) {
        typename map_type::iterator pos = map_.find(key);
        if (pos == map_.end()) {
            return false;
        }
        return ReplaceIfPreferred(pos, value);
    }

    /* Overwrites the mapped value in place (no node erase/re-insert) when Compare accepts the candidate. */
    bool ReplaceIfPreferred(typename map_type::iterator pos, const VALUE_T &value) {
        if (!Compare()(value, pos->second)) {
            return false;
        }
        pos->second = value;
        return true;
    }

private:
    map_type map_;
    /* mutable so that the const read operations can still take a shared lock */
    mutable boost::shared_mutex rwlock_;
};
template<typename KEY_T, typename VALUE_T, typename Compare, typename Hash=std::hash<KEY_T> >
class ConcurrentMap {
public:
ConcurrentMap(uint64_t bucket_nums = 61, Hash const& hasher = Hash()) :
buckets_(bucket_nums),
hasher_ (hasher) {
for (uint64_t i = 0; i < bucket_nums; ++i) {
buckets_[i].reset(new ConcurrentBucket<KEY_T, VALUE_T, Compare>);
}
}
bool LookUp(const KEY_T &key, VALUE_T &value) {
return GetBucket(key).Lookup(key, value);
}
bool Contain(const KEY_T &key) {
return GetBucket(key).Contain(key);
}
bool Insert(const KEY_T &key, const VALUE_T &value) {
return GetBucket(key).Insert(key, value);
}
bool Update(const KEY_T &key, const VALUE_T &value) {
return GetBucket(key).Update(key, value);
}
bool Delete(const KEY_T &key) {
GetBucket(key).Remove(key);
return true;
}
void GetAllKey(std::vector<std::pair<KEY_T, VALUE_T> > &list) {
list.clear();
#if 0
//auto iter = buckets_.begin();
//这里要注意,先得对iter做解引用,得到uniqueptr,然后对unique_ptr做解引用或者用->才能进行调用
typename std::vector<std::unique_ptr<ConcurrentBucket<KEY_T, VALUE_T, Compare> > >::iterator iter = buckets_.begin();
for ( ; iter != buckets_.end(); ++iter ) {
(*iter)->GetAllKey(list);
}
#endif
for (size_t bucket_index = 0; bucket_index < buckets_.size(); ++bucket_index) {
buckets_[bucket_index]->GetAllKey(list);
}
return;
}
void UpdateKeyBatch(std::vector<std::pair<KEY_T, VALUE_T> > &list) {
/* 更新数组的个数需要和桶的个数一致 */
std::vector<std::vector<std::pair<KEY_T, VALUE_T> > > update_lists(buckets_.size());
/* 将更新的元素丢入数组之中 */
typename std::vector<std::pair<KEY_T, VALUE_T> >::iterator iter = list.begin();
for (; iter != list.end(); ++iter) {
std::size_t bucket_index = hasher_(iter->first) % buckets_.size();
update_lists[bucket_index].push_back(std::pair<KEY_T, VALUE_T>(iter->first, iter->second));
}
/* 将对应的元素更新到对应的bucket里面 */
for (size_t bucket_index = 0; bucket_index < buckets_.size(); ++bucket_index) {
buckets_[bucket_index]->UpdateKeyBatch(update_lists[bucket_index]);
}
}
void RemoveKeyBatch(std::vector<KEY_T> &list) {
/* 删除数组的个数需要和桶的个数一致 */
std::vector<std::vector<KEY_T> > remove_lists(buckets_.size());
/* 将删除的元素丢入数组之中 */
typename std::vector<KEY_T>::iterator iter = list.begin();
for (; iter != list.end(); ++iter) {
std::size_t bucket_index = hasher_(*iter) % buckets_.size();
remove_lists[bucket_index].push_back(*iter);
}
/* 将对应的元素更新到对应的bucket里面 */
for (size_t bucket_index = 0; bucket_index < buckets_.size(); ++bucket_index) {
buckets_[bucket_index]->RemoveKeyBatch(remove_lists[bucket_index]);
}
}
void InsertKeyBatch(std::vector<std::pair<KEY_T, VALUE_T> > &list) {
/* 更新数组的个数需要和桶的个数一致 */
std::vector<std::vector<std::pair<KEY_T, VALUE_T> > > Insert_lists(buckets_.size());
/* 将更新的元素丢入数组之中 */
typename std::vector<std::pair<KEY_T, VALUE_T> >::iterator iter = list.begin();
for (; iter != list.end(); ++iter) {
std::size_t bucket_index = hasher_(iter->first) % buckets_.size();
Insert_lists[bucket_index].push_back(std::pair<KEY_T, VALUE_T>(iter->first, iter->second));
}
/* 将对应的元素更新到对应的bucket里面 */
for (size_t bucket_index = 0; bucket_index < buckets_.size(); ++bucket_index) {
buckets_[bucket_index]->InsertKeyBatch(Insert_lists[bucket_index]);
}
}
ConcurrentBucket<KEY_T, VALUE_T, Compare>& GetBucket(KEY_T const& key) const {
std::size_t const bucket_index = hasher_(key)% buckets_.size();
return *buckets_[bucket_index];
}
uint64_t Size() {
uint64_t total_size = 0;
for (size_t i = 0 ; i < buckets_.size(); ++i) {
total_size += buckets_[i]->Size();
}
return total_size;
}
void Clear() {
for (size_t i = 0 ; i < buckets_.size(); ++i) {
buckets_[i]->Clear();
}
}
/* 禁用这两者来保证绝对的安全,主要是本身mutex就是禁止拷贝的*/
/* 禁用拷贝构造 */
ConcurrentMap(ConcurrentMap const &other) = delete;
/* 禁用赋值运算符 */
ConcurrentMap& operator=(ConcurrentMap const &other) = delete;
#if 0
/* 线程安全的拷贝构造函数 */
ConcurrentMap(ConcurrentMap const &other) :
buckets_(other.buckets_.size()),
hasher_(other.hasher_) {
for (uint64_t i = 0; i < buckets_.size(); ++i) {
buckets_[i].reset(new ConcurrentBucket<KEY_T, VALUE_T, Compare>);
}
std::vector<std::pair<KEY_T, VALUE_T> > list;
other.GetAllKey(list);
InsertKeyBatch(list);
}
ConcurrentMap& operator=(ConcurrentMap const &other) {
/* 务必保证两个bucket的大小一致,不提供可伸缩的bucket*/
assert(buckets_.size() == other.buckets_.size());
hasher_ = other.hasher_;
for (uint64_t i = 0; i < buckets_.size(); ++i) {
buckets_[i].reset(new ConcurrentBucket<KEY_T, VALUE_T, Compare>);
}
std::vector<std::pair<KEY_T, VALUE_T> > list;
other.GetAllKey(list);
InsertKeyBatch(list);
}
#endif
private:
/* 使用unique_ptr保证内存安全 */
std::vector<std::unique_ptr<ConcurrentBucket<KEY_T, VALUE_T, Compare> > > buckets_;
Hash hasher_;
};
#endif
2 并发LRU实现
实际上并发LRU的实现和并发MAP是非常相似的:并发MAP是挂了一堆的BUCKET,而并发LRU实际上也是挂了一堆的LRU桶,每个key通过哈希找到对应的LRU。但是严格说,这并不是完全的LRU,因为它拆分出了几个不同的链表。我目前在用的时候使用的还是最简单的LRU,性能比较低但是严格。
#ifndef __LRU_H__
#define __LRU_H__
#include <boost/thread/locks.hpp>
#include <boost/thread/shared_mutex.hpp>
#include <map>
#include <list>
#include <vector>
#include <string>
#include <unordered_map>
/*
 * LRU: a bounded least-recently-used cache guarded by a boost::shared_mutex.
 *
 * list_ holds the keys ordered from most recently used (front) to least
 * recently used (back); map_ maps each key to its value plus the position of
 * that key inside list_. Put and both Get overloads change the recency order
 * and therefore take the exclusive lock; GetAllKey only reads and takes the
 * shared lock.
 */
template<typename KEY_T, typename VALUE_T>
class LRU {
private:
    typedef boost::shared_lock<boost::shared_mutex> read_lock;
    typedef boost::unique_lock<boost::shared_mutex> write_lock;
    typedef typename std::list<KEY_T>::iterator order_iterator;
    typedef std::pair<VALUE_T, order_iterator> entry_type;
    typedef std::map<KEY_T, entry_type> cache_map;

public:
    /* capacity: maximum number of cached entries. With capacity 0 nothing is ever stored. */
    LRU(uint64_t capacity = 31) :
        capacity_(capacity) {
    }

    /*
     * Stores value under key and marks the key as most recently used.
     * When a new key would exceed the capacity, the least recently used entry
     * is evicted first.
     */
    void Put(const KEY_T &key, const VALUE_T &value) {
        write_lock lock(rwlock_);
        typename cache_map::iterator hit = map_.find(key);
        if (hit != map_.end()) {
            touch(hit);
            hit->second.first = value;
            return;
        }
        if (capacity_ == 0) {
            /* Nothing can be cached; also avoids calling back() on an empty list below. */
            return;
        }
        if (map_.size() >= capacity_) {
            map_.erase(list_.back());
            list_.pop_back();
        }
        list_.push_front(key);
        map_.insert(typename cache_map::value_type(key, entry_type(value, list_.begin())));
    }

    /*
     * Copies the cached value into value and marks the key as most recently
     * used. Returns false, leaving value untouched, when the key is not cached.
     */
    bool Get(const KEY_T &key, VALUE_T &value) {
        write_lock lock(rwlock_);
        typename cache_map::iterator hit = map_.find(key);
        if (hit == map_.end()) {
            return false;
        }
        touch(hit);
        value = hit->second.first;
        return true;
    }

    /*
     * Returns the cached value and marks the key as most recently used.
     * A miss returns a default-constructed VALUE_T and caches nothing.
     */
    VALUE_T Get(const KEY_T &key) {
        write_lock lock(rwlock_);
        typename cache_map::iterator hit = map_.find(key);
        if (hit == map_.end()) {
            return VALUE_T();
        }
        touch(hit);
        return hit->second.first;
    }

    /*
     * Appends a copy of every cached (key, value) pair to res, in key order.
     * Does not change the recency order. Always returns true.
     */
    bool GetAllKey(std::vector<std::pair<KEY_T, VALUE_T> > &res) {
        read_lock lock(rwlock_);
        res.reserve(res.size() + map_.size());
        for (typename cache_map::const_iterator it = map_.begin(); it != map_.end(); ++it) {
            res.push_back(std::pair<KEY_T, VALUE_T>(it->first, it->second.first));
        }
        return true;
    }

private:
    /*
     * Moves the key of iter to the front of list_. splice relinks the existing
     * node, so the iterator stored in the map entry stays valid.
     * The caller must hold the exclusive lock.
     */
    void touch(typename cache_map::iterator iter) {
        list_.splice(list_.begin(), list_, iter->second.second);
    }

private:
    uint64_t capacity_;
    boost::shared_mutex rwlock_;
    std::list<KEY_T> list_;
    cache_map map_;
};
template<typename KEY_T, typename VALUE_T, typename Hash=std::hash<KEY_T> >
class ConcurrentLRU {
public:
ConcurrentLRU(uint64_t bucket_nums = 11, Hash const& hasher = Hash()) :
buckets_(bucket_nums),
hasher_ (hasher) {
for (uint64_t i = 0; i < bucket_nums; ++i) {
buckets_[i].reset(new LRU<KEY_T, VALUE_T>);
}
}
void Put(const KEY_T &key, const VALUE_T &value) {
GetBucket(key).Put(key, value);
return;
}
bool Get(const KEY_T &key, VALUE_T &value) {
return GetBucket(key).Get(key, value);
}
/* 没有的话返回一个新的VALUE_T回去,此时缓存是没有的 */
/* 如果有的话放到最开始的地方 */
VALUE_T Get(const KEY_T &key) {
return GetBucket(key).Get(key);
}
void GetAllKey(std::vector<std::pair<KEY_T, VALUE_T> > &list) {
list.clear();
#if 0
//auto iter = buckets_.begin();
//这里要注意,先得对iter做解引用,得到uniqueptr,然后对unique_ptr做解引用或者用->才能进行调用
typename std::vector<std::unique_ptr<LRU<KEY_T, VALUE_T> > >::iterator iter = buckets_.begin();
for ( ; iter != buckets_.end(); ++iter ) {
(*iter)->GetAllKey(list);
}
#endif
for (size_t bucket_index = 0; bucket_index < buckets_.size(); ++bucket_index) {
buckets_[bucket_index]->GetAllKey(list);
}
return;
}
LRU<KEY_T, VALUE_T>& GetBucket(KEY_T const& key) const {
std::size_t const bucket_index = hasher_(key)% buckets_.size();
return *buckets_[bucket_index];
}
/* 禁用拷贝构造 */
ConcurrentLRU(ConcurrentLRU const &other) = delete;
/* 禁用赋值运算符 */
ConcurrentLRU& operator=(ConcurrentLRU const &other) = delete;
private:
/* 使用unique_ptr保证内存安全 */
std::vector<std::unique_ptr<LRU<KEY_T, VALUE_T> > > buckets_;
Hash hasher_;
};
#endif
3 一个线程安全的连接池
下面这个连接池在写的时候由于需要证书更新,因此需要使用TLSSocketWrapper来包裹证书/密钥/ca证书。这里要注意为了实现每次握手的时候只使用新的证书,所以这里会加一个证书的时间戳,来保证用的都是新的证书。可以发现这样子即使证书握手失败也依然会拿新的证书来执行握手。
该连接池首先执行tls握手,然后将tls握手的socket(boost::shared_ptr<ssl::stream<ip::tcp::socket> >)丢到LRU中,来保证lru总是缓存了最新的值,然后发送数据失败的情况下会重新建立连接。为了保活使用了http get对连接做get来探活,每15秒执行一次探活。
头文件:
#ifndef __SYNC_TLS_CLIENT_H__
#define __SYNC_TLS_CLIENT_H__
#include <iostream>
#include <string>
#include "lru.h"
#include "singleton.h"
#include "repeated_timer.h"
#include <string>
#include <istream>
#include <ostream>
#include <cstdlib>
#include <string>
#include <atomic>
#include <boost/asio.hpp>
#include <boost/asio/ssl.hpp>
#include "boost/thread.hpp"
#include "boost/function.hpp"
#include "boost/shared_ptr.hpp"
#include "boost/move/unique_ptr.hpp"
#include "boost/make_shared.hpp"
#include "boost/lockfree/spsc_queue.hpp"
#include <boost/beast/core.hpp>
#include <boost/beast/http.hpp>
#include <boost/beast/version.hpp>
#include <boost/asio/connect.hpp>
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio/ssl/error.hpp>
#include <boost/asio/ssl/stream.hpp>
using tcp = boost::asio::ip::tcp; // from <boost/asio/ip/tcp.hpp>
namespace ssl = boost::asio::ssl; // from <boost/asio/ssl.hpp>
namespace http = boost::beast::http; // from <boost/beast/http.hpp>
using namespace boost::asio;
#define HTTP_11_VERSION 11
#define DEFAULT_HTTP_ALIVE_CONFIG "timeout=60, max=1000"
#define CONTENT_TYPE_JSON "application/json; charset=utf-8"
#define CONTENT_TYPE_PLAIN "text/plain; charset=utf-8"
#define CONTENT_TYPE_HTML "text/html; charset=utf-8"
#define CONNECT_MAX_TRY_TIMES 4
/*
 * TLSSocketWrapper: one TLS connection together with the credentials it was
 * created from (certificate, private key, CA certificate) and the time stamp
 * of those credentials.
 *
 * The constructor performs the TCP connect and the TLS handshake; use
 * HandshakeFinished() to find out whether that succeeded.
 */
class TLSSocketWrapper {
public:
    /* Connects and performs the TLS handshake; failures are reported through HandshakeFinished(). */
    TLSSocketWrapper(const std::string &ip, const std::string &port, boost::asio::io_context &global_io_context, const std::string &cert, const std::string &key, const std::string &ca_cert, uint64_t time_stamp);
    std::string GetCert() const;
    std::string GetKey() const;
    std::string GetCaCert() const;
    uint64_t GetTimeStamp() const;
    /* True when the constructor managed to connect and complete the handshake. */
    bool HandshakeFinished() const;
    /*
     * Send one HTTP GET / POST over the connection and read the response.
     * Both return the HTTP status code, or 500 when the exchange threw.
     * The original author marks these two as not thread-safe.
     */
    int GetMessageCore(const std::string &server_path, std::string &header, std::string &body);
    int PostMessageCore(const std::string &server_path, const std::string &request_content_type, const std::string &request_body, std::string &header, std::string &body);
private:
    typedef boost::shared_lock<boost::shared_mutex> read_lock;
    typedef boost::unique_lock<boost::shared_mutex> write_lock;
    /* boost/move/unique_ptr.hpp declares unique_ptr in namespace boost::movelib. */
    boost::movelib::unique_ptr<ssl::stream<ip::tcp::socket> > real_socket_;
    /* Serializes the write+read of one request/response pair on real_socket_. */
    boost::shared_mutex rwlock_;
    std::string ip_;
    std::string port_;
    std::string cert_;
    std::string key_;
    std::string ca_cert_;
    uint64_t time_stamp_;
};
/*
 * Orders two connections by the time stamp of their credentials.
 * Returns true when a is at least as new as b.
 * Both pointers are dereferenced without a null check.
 */
class TlsWrapperCompare {
public:
    /* const so the functor can also be invoked through a const object or reference */
    bool operator() (const boost::shared_ptr<TLSSocketWrapper> &a, const boost::shared_ptr<TLSSocketWrapper> &b) const {
        return a->GetTimeStamp() >= b->GetTimeStamp();
    }
};
/*
 * Singleton (via hxn::Singleton) that owns the io_context used by the TLS
 * connections, the cache of established connections, and the timer/thread
 * pair that Init() sets up to call HttpHeartBeat() periodically.
 *
 * NOTE(review): ConcurrentLRU is declared as <KEY_T, VALUE_T, Hash>, so the
 * third argument used below (TlsWrapperCompare) lands in the Hash slot, and
 * TlsWrapperCompare is not callable with a single std::string key. Confirm
 * whether the default hash was intended; any change has to be applied to the
 * declaration and the out-of-line definition of GetGlobalSocketWrapper
 * together.
 */
class SyncTlsConnectionManager : public hxn::Singleton<SyncTlsConnectionManager> {
public:
/* Currently this instance is responsible for maintaining the cached connections and keeping them alive. */
void Init();
/* Timer callback: walks the cached connections (see the .cpp for what is actually executed). */
void HttpHeartBeat();
io_context& GetGlobalIOContext();
ConcurrentLRU<std::string, boost::shared_ptr<TLSSocketWrapper>, TlsWrapperCompare>& GetGlobalSocketWrapper();
private:
/* The first io_context is shared by all TLS connections. */
boost::asio::io_context global_io_context_;
/* Cache of connections, keyed by the string "<ip>#<port>". */
ConcurrentLRU<std::string, boost::shared_ptr<TLSSocketWrapper>, TlsWrapperCompare> global_sockets_;
/* Asynchronous timer + thread, implemented with boost asio + chrono. */
boost::shared_ptr<boost::thread> loop_;
boost::shared_ptr<RepeatedTimer> timer_;
/* This io_service is dedicated to the global timer. */
boost::shared_ptr<io_service> io_service_;
};
/*
 * SyncTlsClient: synchronous HTTP-over-TLS client bound to one ip/port (or
 * host/port) pair. It holds a shared TLSSocketWrapper and replaces it with a
 * freshly connected one when a request fails.
 *
 * Original author's note: synchronous reads/writes on an asio socket are
 * thread-safe. NOTE(review): the request/response exchange is serialized by
 * the lock inside TLSSocketWrapper; verify the claim about asio itself
 * against the Asio documentation.
 */
class SyncTlsClient {
public:
/* Constructor dedicated to the probe / keep-alive path: reuses an existing connection. */
SyncTlsClient(const std::string &ip, const std::string &port, const boost::shared_ptr<TLSSocketWrapper> &connection);
/* time_stamp is the time stamp of the certificate (i.e. of the certificate file). */
SyncTlsClient(const std::string &ip, const std::string &port, const std::string &cert, const std::string &key, const std::string &ca_cert, const uint64_t &time_stamp);
/* The HTTP heartbeat is in fact a GET request. */
int HeartBeat(const std::string &server_path, std::string &response_header, std::string &response_body);
/* Both return the HTTP status code; 500 is also used when no usable connection exists. */
int GetMessage(const std::string &server_path, std::string &response_header, std::string &response_body);
int PostMessage(const std::string &server_path, const std::string &request_content_type, const std::string &request_body, std::string &response_header, std::string &response_body);
private:
/* Connection currently used by this client; shared with the global cache. */
boost::shared_ptr<TLSSocketWrapper> socket_wrapper_;
/* May be ip/port or host/port. */
std::string ip_;
std::string port_;
};
#endif
cpp文件
#include "sync_tls_client.h"
/*
 * Connects to ip:port and performs the TLS client handshake, trying up to
 * CONNECT_MAX_TRY_TIMES times. On success real_socket_ holds the connected
 * stream; if every attempt throws, real_socket_ stays null and
 * HandshakeFinished() returns false.
 *
 * ip may be a literal address or a host name; a string that does not parse as
 * an address is resolved through tcp::resolver.
 * cert/key are only loaded when both are non-empty.
 * When ca_cert is non-empty it is installed as a trusted authority and peer
 * verification is switched on, so the handshake fails if the server
 * certificate does not chain to it. Host-name matching is not configured.
 */
TLSSocketWrapper::TLSSocketWrapper(const std::string &ip, const std::string &port, boost::asio::io_context &global_io_context, const std::string &cert, const std::string &key, const std::string &ca_cert, uint64_t time_stamp) :
    ip_(ip),
    port_(port),
    cert_(cert),
    key_(key),
    ca_cert_(ca_cert),
    time_stamp_(time_stamp) {
    for (int attempt = 0; (real_socket_ == nullptr) && (attempt < CONNECT_MAX_TRY_TIMES); ++attempt) {
        try {
            ssl::context ssl_context{ssl::context::sslv23_client};
            /* Client credentials are optional, but only usable as a pair. */
            if (!cert.empty() && !key.empty()) {
                ssl_context.use_certificate_chain(buffer(cert.data(), cert.size()));
                ssl_context.use_private_key(buffer(key.data(), key.size()), ssl::context::file_format::pem);
            }
            if (!ca_cert.empty()) {
                ssl_context.add_certificate_authority(buffer(ca_cert.data(), ca_cert.size()));
                /* Without this the CA above is loaded but the server certificate is never checked against it. */
                ssl_context.set_verify_mode(ssl::verify_peer);
            }
            real_socket_.reset(new ssl::stream<tcp::socket>(global_io_context, ssl_context));

            boost::system::error_code ip_invalid;
            ip::address ip_address = ip::address::from_string(ip, ip_invalid);
            if (ip_invalid) {
                /* Not a literal address: treat it as a host name and resolve it. */
                tcp::resolver resolver{global_io_context};
                auto const results = resolver.resolve(ip, port);
                connect(real_socket_->next_layer(), results.begin(), results.end());
            } else {
                tcp::endpoint endpoint(ip_address, std::stoi(port));
                real_socket_->next_layer().connect(endpoint);
            }
            /* Enable TCP keep-alive on the underlying socket. */
            real_socket_->next_layer().set_option(socket_base::keep_alive(true));
            real_socket_->handshake(ssl::stream_base::client);
        } catch (std::exception const& e) {
            std::cerr << "Handshake Error: " << e.what() << std::endl;
            /* Drop the half-initialized stream so the loop retries and failure stays observable. */
            real_socket_ = nullptr;
        }
    }
}
std::string TLSSocketWrapper::GetCert() const {
return cert_;
}
std::string TLSSocketWrapper::GetKey() const {
return key_;
}
uint64_t TLSSocketWrapper::GetTimeStamp() const {
return time_stamp_;
}
bool TLSSocketWrapper::HandshakeFinished() const {
return real_socket_ != nullptr;
}
std::string TLSSocketWrapper::GetCaCert() const {
return ca_cert_;
}
/*
 * Sends an HTTP/1.1 GET for server_path over the TLS stream and reads the
 * response.
 *
 * header: receives the response header fields, one "Name: value\r\n" line per
 *         field (the status line is not included).
 * body:   receives the response body.
 * Returns the HTTP status code of the response (200 means the GET succeeded).
 * Returns 500 when there is no connected stream or when writing/reading
 * throws; header and body are left untouched in that case.
 */
int TLSSocketWrapper::GetMessageCore(const std::string &server_path, std::string &header, std::string &body) {
    if (real_socket_ == nullptr) {
        /* The constructor never completed a handshake; there is nothing to write to. */
        return 500;
    }
    int error_code;
    try {
        http::request<http::string_body> req(http::verb::get, server_path, HTTP_11_VERSION);
        req.set(http::field::host, ip_);
        req.set(http::field::user_agent, BOOST_BEAST_VERSION_STRING);
        req.set(http::field::keep_alive, DEFAULT_HTTP_ALIVE_CONFIG);
        boost::beast::flat_buffer buffer;
        http::response<http::dynamic_body> res;
        /* Hold the lock only for the wire exchange so request and response stay paired. */
        {
            write_lock lock(rwlock_);
            http::write(*real_socket_, req);
            http::read(*real_socket_, buffer, res);
        }
        std::string response_header;
        for (auto const &field : res.base()) {
            response_header.append(field.name_string().data(), field.name_string().size());
            response_header.append(": ");
            response_header.append(field.value().data(), field.value().size());
            response_header.append("\r\n");
        }
        std::string response_body(buffers_begin(res.body().data()), buffers_end(res.body().data()));
        header = response_header;
        body = response_body;
        /* The HTTP status code is the return value. */
        error_code = res.result_int();
    } catch (std::exception const& e) {
        std::cerr << "Get Message Error: " << e.what() << std::endl;
        error_code = 500;
    }
    return error_code;
}
/*
 * Sends an HTTP/1.1 POST for server_path with the given content type and body
 * over the TLS stream and reads the response.
 *
 * header: receives the response header fields, one "Name: value\r\n" line per
 *         field (the status line is not included).
 * body:   receives the response body.
 * Returns the HTTP status code of the response.
 * Returns 500 when there is no connected stream or when writing/reading
 * throws; header and body are left untouched in that case.
 */
int TLSSocketWrapper::PostMessageCore(const std::string &server_path, const std::string &request_content_type, const std::string &request_body, std::string &header, std::string &body) {
    if (real_socket_ == nullptr) {
        /* The constructor never completed a handshake; there is nothing to write to. */
        return 500;
    }
    int error_code;
    try {
        http::request<http::string_body> req(http::verb::post, server_path, HTTP_11_VERSION);
        req.set(http::field::host, ip_);
        req.set(http::field::user_agent, BOOST_BEAST_VERSION_STRING);
        req.set(http::field::keep_alive, DEFAULT_HTTP_ALIVE_CONFIG);
        req.set(http::field::content_type, request_content_type);
        req.body() = request_body;
        /* prepare_payload() updates Content-Length in the header to match the body just set. */
        req.prepare_payload();
        boost::beast::flat_buffer buffer;
        http::response<http::dynamic_body> res;
        /* Hold the lock only for the wire exchange so request and response stay paired. */
        {
            write_lock lock(rwlock_);
            http::write(*real_socket_, req);
            http::read(*real_socket_, buffer, res);
        }
        std::string response_header;
        for (auto const &field : res.base()) {
            response_header.append(field.name_string().data(), field.name_string().size());
            response_header.append(": ");
            response_header.append(field.value().data(), field.value().size());
            response_header.append("\r\n");
        }
        std::string response_body(buffers_begin(res.body().data()), buffers_end(res.body().data()));
        header = response_header;
        body = response_body;
        /* The HTTP status code is the return value. */
        error_code = res.result_int();
    } catch (std::exception const& e) {
        std::cerr << "Post Message Error: " << e.what() << std::endl;
        error_code = 500;
    }
    return error_code;
}
void SyncTlsConnectionManager::Init() {
io_service_ = boost::shared_ptr<io_service>(new io_service());
/* 添加work防止io_service 在pending时退出*/
io_service::work work{*io_service_};
this->loop_ = boost::shared_ptr<boost::thread> (new boost::thread([&]{ io_service_->run(); }));
/* lambda需要捕获this才能调用this的函数,如果是静态或者全局函数那么不需要执行捕获 */
timer_ = boost::shared_ptr<RepeatedTimer> (new RepeatedTimer(*io_service_, [this](const boost::system::error_code &e) { this->HttpHeartBeat();}));
/* 15秒钟运行一次心跳保活线程 */
timer_->Start(/*ms=*/1000*15);
}
/*
 * Timer callback. Takes a snapshot of the cached connections and splits each
 * cache key of the form "<ip>#<port>" into its two parts. The actual probe
 * request is currently disabled (kept below as a comment), so the function
 * has no effect beyond reading the cache.
 */
void SyncTlsConnectionManager::HttpHeartBeat() {
    std::vector<std::pair<std::string, boost::shared_ptr<TLSSocketWrapper> > > cached;
    global_sockets_.GetAllKey(cached);
    for (const auto &entry : cached) {
        const std::string &ref = entry.first;
        const size_t sep = ref.find('#');
        if (sep != ref.npos) {
            std::string ip(ref.substr(0, sep));
            std::string port(ref.substr(sep + 1));
            /*
            SyncTlsClient heartbeat(ip, port, iter.second);
            std::string header;
            std::string body;
            int error_code = heartbeat.HeartBeat("/monitor/alive", header, body);
            if (error_code != 200) {
            }
            */
        }
        /* A key without '#' should not occur in theory; such entries are skipped. */
    }
}
/* Returns the io_context shared by all TLS connections. */
io_context& SyncTlsConnectionManager::GetGlobalIOContext() {
    return this->global_io_context_;
}

/* Returns the cache of established connections, keyed by "<ip>#<port>". */
ConcurrentLRU<std::string, boost::shared_ptr<TLSSocketWrapper>, TlsWrapperCompare>& SyncTlsConnectionManager::GetGlobalSocketWrapper() {
    return this->global_sockets_;
}
/*
 * Constructor dedicated to the probe / keep-alive path: wraps an already
 * established connection instead of creating one.
 * The initializer list follows the member declaration order
 * (socket_wrapper_, ip_, port_) so that it reflects the real initialization
 * order and does not trigger -Wreorder.
 */
SyncTlsClient::SyncTlsClient(const std::string &ip, const std::string &port, const boost::shared_ptr<TLSSocketWrapper> &connection):
    socket_wrapper_(connection),
    ip_(ip),
    port_(port) {
}
/*
 * Creates a client for ip:port using the given credentials. time_stamp is the
 * time stamp of the certificate (i.e. of the certificate file).
 *
 * A cached connection for "<ip>#<port>" is reused only when its credentials
 * are at least as new as time_stamp. Otherwise a new connection is opened;
 * it replaces the cached one only if its handshake succeeded, but this client
 * keeps the new wrapper either way.
 */
SyncTlsClient::SyncTlsClient(const std::string &ip, const std::string &port, const std::string &cert, const std::string &key, const std::string &ca_cert, const uint64_t &time_stamp) :
    ip_(ip),
    port_(port) {
    const std::string cache_key = ip + "#" + port;
    const bool cached = SyncTlsConnectionManager::instance().GetGlobalSocketWrapper().Get(cache_key, socket_wrapper_);
    if (cached && time_stamp <= socket_wrapper_->GetTimeStamp()) {
        /* The only case that needs no new handshake: the cached credentials are current. */
        return;
    }

    boost::shared_ptr<TLSSocketWrapper> fresh(new TLSSocketWrapper(ip, port, SyncTlsConnectionManager::instance().GetGlobalIOContext(), cert, key, ca_cert, time_stamp));
    if (fresh != nullptr && fresh->HandshakeFinished()) {
        /* The global cache is only updated after a successful handshake. */
        SyncTlsConnectionManager::instance().GetGlobalSocketWrapper().Put(cache_key, fresh);
    }
    socket_wrapper_ = fresh;
}
/* http心跳报文实际上是个get请求 */
int SyncTlsClient::HeartBeat(const std::string &server_path, std::string &response_header, std::string &response_body) {
int error_code;
/* real_socket_理论来说不能为空,为空直接说明一开始尝试连接失败,直接返回错误码500即可 */
if (socket_wrapper_ != nullptr && socket_wrapper_->HandshakeFinished()) {
/* asio的is_open是个无效函数,因此只能通过尝试发送数据来探测是否该socket还活着/出现异常 */
error_code = socket_wrapper_->GetMessageCore(server_path, response_header, response_body);
/* 500代表连接失效了,需要重新连接,并发起请求 */
if (error_code == 500) {
/* 因为永远保证socket_wrapper的证书/私钥/时间戳都是最新版,所以使用socket_wrapper_的证书私钥建立连接 */
boost::shared_ptr<TLSSocketWrapper> new_socket_wrapper = boost::shared_ptr<TLSSocketWrapper> (new TLSSocketWrapper(ip_, port_, SyncTlsConnectionManager::instance().GetGlobalIOContext(), socket_wrapper_->GetCert(), socket_wrapper_->GetKey(), socket_wrapper_->GetCaCert(), socket_wrapper_->GetTimeStamp()));
if (new_socket_wrapper != nullptr && new_socket_wrapper->HandshakeFinished()) {
error_code = new_socket_wrapper->GetMessageCore(server_path, response_header, response_body);
if (error_code == 200) {
std::string ref = ip_+"#"+port_;
SyncTlsConnectionManager::instance().GetGlobalSocketWrapper().Put(ref, new_socket_wrapper);
socket_wrapper_ = new_socket_wrapper;
}
}
}
} else {
error_code = 500;
}
return error_code;
}
int SyncTlsClient::GetMessage(const std::string &server_path, std::string &response_header, std::string &response_body) {
int error_code;
/* real_socket_理论来说不能为空,为空直接说明一开始尝试连接失败,直接返回错误码500即可 */
if (socket_wrapper_ != nullptr && socket_wrapper_->HandshakeFinished()) {
/* asio的is_open是个无效函数,因此只能通过尝试发送数据来探测是否该socket还活着/出现异常 */
error_code = socket_wrapper_->GetMessageCore(server_path, response_header, response_body);
/* 500代表连接失效了,需要重新连接,并发起请求 */
if (error_code == 500) {
/* 因为永远保证socket_wrapper的证书/私钥/时间戳都是最新版,所以使用socket_wrapper_的证书私钥建立连接 */
boost::shared_ptr<TLSSocketWrapper> new_socket_wrapper = boost::shared_ptr<TLSSocketWrapper> (new TLSSocketWrapper(ip_, port_, SyncTlsConnectionManager::instance().GetGlobalIOContext(), socket_wrapper_->GetCert(), socket_wrapper_->GetKey(), socket_wrapper_->GetCaCert(), socket_wrapper_->GetTimeStamp()));
if (new_socket_wrapper != nullptr && new_socket_wrapper->HandshakeFinished()) {
error_code = new_socket_wrapper->GetMessageCore(server_path, response_header, response_body);
if (error_code == 200) {
std::string ref = ip_+"#"+port_;
SyncTlsConnectionManager::instance().GetGlobalSocketWrapper().Put(ref, new_socket_wrapper);
socket_wrapper_ = new_socket_wrapper;
}
}
}
} else {
error_code = 500;
}
return error_code;
}
int SyncTlsClient::PostMessage(const std::string &server_path, const std::string &request_content_type, const std::string &request_body, std::string &response_header, std::string &response_body) {
int error_code;
/* socket_理论来说不能为空,为空直接说明地址无效了,直接返回错误码500即可 */
if (socket_wrapper_ != nullptr && socket_wrapper_->HandshakeFinished()) {
/* asio的is_open是个无效函数,因此只能通过尝试发送数据来探测是否该socket还活着 */
error_code = socket_wrapper_->PostMessageCore(server_path, request_content_type, request_body, response_header, response_body);
/* 500代表连接失效了,需要重新连接,并发起请求 */
if (error_code == 500) {
/* 因为永远保证socket_wrapper的证书/私钥/时间戳都是最新版,所以使用socket_wrapper_的证书私钥建立连接 */
boost::shared_ptr<TLSSocketWrapper> new_socket_wrapper = boost::shared_ptr<TLSSocketWrapper> (new TLSSocketWrapper(ip_, port_, SyncTlsConnectionManager::instance().GetGlobalIOContext(), socket_wrapper_->GetCert(), socket_wrapper_->GetKey(), socket_wrapper_->GetCaCert(), socket_wrapper_->GetTimeStamp()));
if (new_socket_wrapper != nullptr && new_socket_wrapper->HandshakeFinished()) {
error_code = new_socket_wrapper->PostMessageCore(server_path, request_content_type, request_body, response_header, response_body);
if (error_code == 200) {
std::string ref = ip_+"#"+port_;
SyncTlsConnectionManager::instance().GetGlobalSocketWrapper().Put(ref, new_socket_wrapper);
socket_wrapper_ = new_socket_wrapper;
}
}
}
} else {
error_code = 500;
}
return error_code;
}
5 OPENSSL签名验证签名代码
本来是只用crypto++库做的,然后惊讶地发现crypto++的签名/验签的性能低得令人发指,所以最后又采用了openssl的代码。
头文件
#ifndef __SIGN_ALGORITHM_H__
#define __SIGN_ALGORITHM_H__
/* inf cbom的crypto后面没有pp,mac装的需要加个pp */
#ifndef MAC_BUILD
#include <crypto/eccrypto.h>
#include <crypto/osrng.h>
#include <crypto/rsa.h>
#include <crypto/pssr.h>
#include <crypto/base64.h>
#else
#include <cryptopp/eccrypto.h>
#include <cryptopp/osrng.h>
#include <cryptopp/rsa.h>
#include <cryptopp/pssr.h>
#include <cryptopp/base64.h>
#endif
#include <string>
#include <iostream>
#include "sts_enum.h"
#include "openssl/ecdsa.h"
#include "openssl/rsa.h"
#include "openssl/ssl.h"
#include "openssl/ecdh.h"
#include "openssl/evp.h"
#include "openssl/sha.h"
#include "openssl/bio.h"
#include "openssl/pem.h"
#include "openssl/err.h"
#include "openssl/hmac.h"
#define EVP_MAX_SIGNATURE_SIZE 1024
using std::string;
class SignAlgorithm{
public:
static bool sign_with_pem(const int &algorithm, const std::string &message, const std::string &key, std::string &signature);
static bool verify_with_pem(const int &algorithm, const std::string &message, const std::string &key, std::string &signature);
static bool base64_url_encode(const std::string &in, std::string &out);
static bool base64_url_decode(const std::string &in, std::string &out);
static bool base64_encode(const std::string &in, std::string &out);
static bool base64_decode(const std::string &in, std::string &out);
template <typename T>
static bool load_base64_raw_key(const std::string &pem, T &key) {
ByteQueue queue;
Base64Decoder decoder;
decoder.Attach(new Redirector(queue));
decoder.Put((const byte*)pem.data(), pem.length());
decoder.MessageEnd();
try {
key.BERDecode(queue);
#ifdef AUTH_SDK_VALIDATE_KEY
AutoSeededRandomPool prng;
bool valid = key.Validate(prng, 3);
if (!valid) {
return false;
}
#endif
} catch (const Exception& ex) {
return false;
}
return true;
}
private:
/* 公私密钥均为pem格式*/
static bool ecdsa_sha256_sign(const std::string &message, const std::string &pem_private_key, std::string &signature);
static bool ecdsa_sha256_verify(const std::string &message, const std::string &pem_pub_key, std::string &signature);
static bool ecdsa_sha256_sign_openssl(const std::string &message, const std::string &pem_private_key, std::string &signature);
static bool ecdsa_sha256_verify_openssl(const std::string &message, const std::string &pem_pub_key, std::string &signature);
static bool rsa_sha256_sign_openssl(const std::string &message, const std::string &pem_private_key, std::string &signature);
static bool rsa_sha256_verify_openssl(const std::string &message, const std::string &pem_pub_key, std::string &signature);
static bool rsa_sha256_sign(const std::string &message, const std::string &pem_private_key, std::string &signature) ;
static bool rsa_sha256_verify(const std::string &message, const std::string &pem_pub_key, std::string &signature);
static bool hmac_sha256_sign(const std::string &message, const std::string &key, std::string &signature);
static bool hmac_sha256_verify(const std::string &message, const std::string &key, std::string &signature);
static bool hmac_sha256_sign_openssl(const std::string &message, const std::string &key, std::string &signature);
static bool hmac_sha256_verify_openssl(const std::string &message, const std::string &key, std::string &signature);
};
#endif
cpp文件
#include "sign_algorithm.h"
/*
 * Signs message with key using the algorithm selected by the STS_ALG_TYPE_*
 * id. The Crypto++ implementations are used when USE_CRYPTOPP_ALG is defined,
 * the OpenSSL ones otherwise. Returns false for an unrecognized algorithm id.
 */
bool SignAlgorithm::sign_with_pem(const int &algorithm, const std::string &message, const std::string &key, std::string &signature) {
#ifdef USE_CRYPTOPP_ALG
    if (algorithm == STS_ALG_TYPE_RS256) {
        return SignAlgorithm::rsa_sha256_sign(message, key, signature);
    }
    if (algorithm == STS_ALG_TYPE_ES256) {
        return SignAlgorithm::ecdsa_sha256_sign(message, key, signature);
    }
    if (algorithm == STS_ALG_TYPE_HS256) {
        return SignAlgorithm::hmac_sha256_sign(message, key, signature);
    }
#else
    if (algorithm == STS_ALG_TYPE_RS256) {
        return SignAlgorithm::rsa_sha256_sign_openssl(message, key, signature);
    }
    if (algorithm == STS_ALG_TYPE_ES256) {
        return SignAlgorithm::ecdsa_sha256_sign_openssl(message, key, signature);
    }
    if (algorithm == STS_ALG_TYPE_HS256) {
        return SignAlgorithm::hmac_sha256_sign_openssl(message, key, signature);
    }
#endif
    return false;
}
/*
 * Verifies signature over message with key using the algorithm selected by
 * the STS_ALG_TYPE_* id. The Crypto++ implementations are used when
 * USE_CRYPTOPP_ALG is defined, the OpenSSL ones otherwise.
 * Returns false for an unrecognized algorithm id.
 */
bool SignAlgorithm::verify_with_pem(const int &algorithm, const std::string &message, const std::string &key, std::string &signature) {
#ifdef USE_CRYPTOPP_ALG
    if (algorithm == STS_ALG_TYPE_RS256) {
        return SignAlgorithm::rsa_sha256_verify(message, key, signature);
    }
    if (algorithm == STS_ALG_TYPE_ES256) {
        return SignAlgorithm::ecdsa_sha256_verify(message, key, signature);
    }
    if (algorithm == STS_ALG_TYPE_HS256) {
        return SignAlgorithm::hmac_sha256_verify(message, key, signature);
    }
#else
    if (algorithm == STS_ALG_TYPE_RS256) {
        return SignAlgorithm::rsa_sha256_verify_openssl(message, key, signature);
    }
    if (algorithm == STS_ALG_TYPE_ES256) {
        return SignAlgorithm::ecdsa_sha256_verify_openssl(message, key, signature);
    }
    if (algorithm == STS_ALG_TYPE_HS256) {
        return SignAlgorithm::hmac_sha256_verify_openssl(message, key, signature);
    }
#endif
    return false;
}
bool SignAlgorithm::ecdsa_sha256_sign(const std::string &message, const std::string &pem_private_key, std::string &signature) {
CryptoPP::ECDSA<CryptoPP::ECP, CryptoPP::SHA256>::PrivateKey private_key;
if (!load_base64_raw_key(pem_private_key, private_key)) {
return false;
}
CryptoPP::AutoSeededRandomPool prng;
CryptoPP::ECDSA<CryptoPP::ECP, CryptoPP::SHA256>::Signer signer(private_key);
try {
CryptoPP::StringSource s( message, true /* pump all */,
new CryptoPP::SignerFilter(prng,
signer,
new CryptoPP::StringSink( signature )
)
);
} catch (CryptoPP::Exception e) {
return false;
}
return true;
}
bool SignAlgorithm::ecdsa_sha256_verify(const std::string &message, const std::string &pem_pub_key, std::string &signature) {
CryptoPP::ECDSA<CryptoPP::ECP, CryptoPP::SHA256>::PublicKey public_key;
if (!load_base64_raw_key(pem_pub_key, public_key)) {
return false;
}
bool ret = false;
CryptoPP::ECDSA<CryptoPP::ECP, CryptoPP::SHA256>::Verifier verifier(public_key);
try {
CryptoPP::StringSource s( signature+message, true /* pump all */,
/* 默认就是把校验结果放到ret里 */
new CryptoPP::SignatureVerificationFilter(
verifier,
new CryptoPP::ArraySink( (byte*)&ret, sizeof(ret) )
)
);
} catch (CryptoPP::Exception e) {
ret = false;
}
return ret;
}
bool SignAlgorithm::rsa_sha256_sign(const std::string &message, const std::string &pem_private_key, std::string &signature) {
CryptoPP::RSA::PrivateKey private_key;
if (!load_base64_raw_key(pem_private_key, private_key)) {
return false;
}
CryptoPP::RSASS<CryptoPP::PSS, CryptoPP::SHA256>::Signer signer(private_key);
CryptoPP::AutoSeededRandomPool prng;
bool ret = true;
try {
CryptoPP::StringSource s(message, true,
new CryptoPP::SignerFilter(prng, signer,
new CryptoPP::StringSink( signature )
)
);
} catch (CryptoPP::Exception e) {
ret = false;
}
return ret;
}
bool SignAlgorithm::rsa_sha256_verify(const std::string &message, const std::string &pem_pub_key, std::string &signature) {
CryptoPP::RSA::PublicKey public_key;
if (!load_base64_raw_key(pem_pub_key, public_key)) {
return false;
}
CryptoPP::RSASS<CryptoPP::PSS, CryptoPP::SHA256>::Verifier verifier(public_key);
bool ret = true;
try {
CryptoPP::StringSource s(signature+message, true,
new CryptoPP::SignatureVerificationFilter(
verifier,
new CryptoPP::ArraySink( (byte*)&ret, sizeof(ret) )
)
);
} catch (CryptoPP::Exception e) {
ret = false;
}
return ret;
}
bool SignAlgorithm::hmac_sha256_sign(const std::string &message, const std::string &key, std::string &signature) {
bool ret = true;
try {
/* 输入的密钥是base64编码的,所以需要先解码 */
std::string base64_key;
SignAlgorithm::base64_decode(key, base64_key);
CryptoPP::HMAC<CryptoPP::SHA256> hmac((byte*) base64_key.c_str(), base64_key.length());
CryptoPP::StringSource(message, true, new CryptoPP::HashFilter(hmac, new CryptoPP::StringSink(signature)));
} catch (CryptoPP::Exception e) {
ret = false;
}
return ret;
}
bool SignAlgorithm::hmac_sha256_verify(const std::string &message, const std::string &key, std::string &signature) {
bool ret = true;
std::string computed_signature;
try {
/* 输入的密钥是base64编码的,所以需要先解码 */
std::string base64_key;
SignAlgorithm::base64_decode(key, base64_key);
CryptoPP::HMAC<CryptoPP::SHA256> hmac((byte*) base64_key.c_str(), base64_key.length());
CryptoPP::StringSource(message, true, new CryptoPP::HashFilter(hmac, new CryptoPP::StringSink(computed_signature)));
if (computed_signature.compare(signature) != 0) {
ret = false;
}
} catch (CryptoPP::Exception e) {
ret = false;
}
return ret;
}
bool SignAlgorithm::base64_url_encode(const std::string &in, std::string &out) {
bool ret = true;
out.clear();
try {
CryptoPP::StringSource(in, true, new CryptoPP::Base64URLEncoder(new CryptoPP::StringSink(out)));
} catch (CryptoPP::Exception e) {
ret = false;
}
return ret;
}
bool SignAlgorithm::base64_url_decode(const std::string &in, std::string &out) {
bool ret = true;
out.clear();
try {
CryptoPP::StringSource(in, true, new CryptoPP::Base64URLDecoder(new CryptoPP::StringSink(out)));
} catch (CryptoPP::Exception e) {
ret = false;
}
return ret;
}
bool SignAlgorithm::base64_encode(const std::string &in, std::string &out) {
bool ret = true;
out.clear();
try {
CryptoPP::StringSource(in, true, new CryptoPP::Base64Encoder(new CryptoPP::StringSink(out)));
} catch (CryptoPP::Exception e) {
ret = false;
}
return ret;
}
bool SignAlgorithm::base64_decode(const std::string &in, std::string &out) {
bool ret = true;
out.clear();
try {
CryptoPP::StringSource(in, true, new CryptoPP::Base64Decoder(new CryptoPP::StringSink(out)));
} catch (CryptoPP::Exception e) {
ret = false;
}
return ret;
}
/* 签名的最后结果并不一定小于64,对ECDSA而言256bit可以达到72字节。所以signature_buf长度最好保存1024,直接防备rsa 4096/1024*8的密钥 */
bool SignAlgorithm::ecdsa_sha256_sign_openssl(const std::string &message, const std::string &pem_private_key, std::string &signature) {
uint32_t digest_len = 0;
uint32_t signature_len = 0;
unsigned char digest[EVP_MAX_MD_SIZE];
char signature_buf[EVP_MAX_SIGNATURE_SIZE];
std::string key = "-----BEGIN PRIVATE KEY-----\n" + pem_private_key + "\n-----END PRIVATE KEY-----";
BIO *bio = BIO_new(BIO_s_mem());
if (bio == nullptr) {
return false;
}
BIO_puts(bio, key.c_str());
EC_KEY *ec_key = PEM_read_bio_ECPrivateKey(bio, NULL, NULL, NULL);;
if (ec_key == nullptr) {
BIO_free(bio);
return false;
}
EVP_MD_CTX *md_ctx = EVP_MD_CTX_create();
if (md_ctx == nullptr) {
BIO_free(bio);
EC_KEY_free(ec_key);
return false;
}
bool ret = true;
bzero(digest, EVP_MAX_MD_SIZE);
if (EVP_DigestInit(md_ctx, EVP_sha256()) <= 0
||EVP_DigestUpdate(md_ctx, (const void *)message.c_str(), message.length()) <= 0
||EVP_DigestFinal(md_ctx, digest, &digest_len) <= 0
||ECDSA_sign(0, digest, digest_len, (unsigned char*)signature_buf, &signature_len, ec_key) <= 0) {
ret = false;
}
if (ret == true) {
std::string final_sig(signature_buf, signature_len);
signature = final_sig;
}
BIO_free(bio);
EC_KEY_free(ec_key);
EVP_MD_CTX_destroy(md_ctx);
return ret;
}
bool SignAlgorithm::ecdsa_sha256_verify_openssl(const std::string &message, const std::string &pem_pub_key, std::string &signature) {
uint32_t digest_len = 0;
int signature_len = signature.length();
unsigned char digest[EVP_MAX_MD_SIZE];
std::string key = "-----BEGIN PUBLIC KEY-----\n" + pem_pub_key + "\n-----END PUBLIC KEY-----";
BIO *bio = BIO_new(BIO_s_mem());
if (bio == nullptr) {
return false;
}
BIO_puts(bio, key.c_str());
EC_KEY *ec_key = PEM_read_bio_EC_PUBKEY(bio, NULL, NULL, NULL);;
if (ec_key == nullptr) {
BIO_free(bio);
return false;
}
EVP_MD_CTX *md_ctx = EVP_MD_CTX_create();
if (md_ctx == nullptr) {
BIO_free(bio);
EC_KEY_free(ec_key);
return false;
}
bool ret = true;
bzero(digest, EVP_MAX_MD_SIZE);
if (EVP_DigestInit(md_ctx, EVP_sha256()) <= 0
||EVP_DigestUpdate(md_ctx, (const void *)message.c_str(), message.length()) <= 0
||EVP_DigestFinal(md_ctx, digest, &digest_len) <= 0
||ECDSA_verify(0, digest, digest_len, (const unsigned char*)signature.c_str(), signature_len, ec_key) <= 0) {
ret = false;
}
BIO_free(bio);
EC_KEY_free(ec_key);
EVP_MD_CTX_destroy(md_ctx);
return ret;
}
bool SignAlgorithm::rsa_sha256_sign_openssl(const std::string &message, const std::string &pem_private_key, std::string &signature) {
uint32_t digest_len = 0;
uint32_t signature_len = 0;
unsigned char digest[EVP_MAX_MD_SIZE];
char signature_buf[EVP_MAX_SIGNATURE_SIZE];
std::string key = "-----BEGIN PRIVATE KEY-----\n" + pem_private_key + "\n-----END PRIVATE KEY-----";
BIO *bio = BIO_new(BIO_s_mem());
if (bio == nullptr) {
return false;
}
BIO_puts(bio, key.c_str());
RSA *rsa_key = PEM_read_bio_RSAPrivateKey(bio, NULL, NULL, NULL);;
if (rsa_key == nullptr) {
BIO_free(bio);
return false;
}
EVP_MD_CTX *md_ctx = EVP_MD_CTX_create();
if (md_ctx == nullptr) {
BIO_free(bio);
RSA_free(rsa_key);
return false;
}
bool ret = true;
if (EVP_DigestInit(md_ctx, EVP_sha256()) <= 0
||EVP_DigestUpdate(md_ctx, (const void *)message.c_str(), message.length()) <= 0
||EVP_DigestFinal(md_ctx, digest, &digest_len) <= 0
||RSA_sign(NID_sha256, digest, digest_len, (unsigned char*)signature_buf, &signature_len, rsa_key) <= 0) {
ret = false;
}
if (ret == true) {
std::string final_sig(signature_buf, signature_len);
signature = final_sig;
}
BIO_free(bio);
RSA_free(rsa_key);
EVP_MD_CTX_destroy(md_ctx);
return ret;
}
bool SignAlgorithm::rsa_sha256_verify_openssl(const std::string &message, const std::string &pem_pub_key, std::string &signature) {
uint32_t digest_len = 0;
int signature_len = signature.length();
unsigned char digest[EVP_MAX_MD_SIZE];
std::string key = "-----BEGIN PUBLIC KEY-----\n" + pem_pub_key + "\n-----END PUBLIC KEY-----";
BIO *bio = BIO_new(BIO_s_mem());
if (bio == nullptr) {
return false;
}
BIO_puts(bio, key.c_str());
RSA* rsa_key = PEM_read_bio_RSA_PUBKEY(bio, NULL, NULL, NULL);
if (rsa_key == nullptr) {
BIO_free(bio);
return false;
}
EVP_MD_CTX *md_ctx = EVP_MD_CTX_create();
if (md_ctx == nullptr) {
BIO_free(bio);
RSA_free(rsa_key);
return false;
}
bool ret = true;
if (EVP_DigestInit(md_ctx, EVP_sha256()) <= 0
||EVP_DigestUpdate(md_ctx, (const void *)message.c_str(), message.length())<= 0
||EVP_DigestFinal(md_ctx, digest, &digest_len)<= 0
||RSA_verify(NID_sha256, digest, digest_len, (const unsigned char*)signature.c_str(), signature_len, rsa_key)<= 0) {
ret = false;
}
BIO_free(bio);
RSA_free(rsa_key);
EVP_MD_CTX_destroy(md_ctx);
return ret;
}
bool SignAlgorithm::hmac_sha256_sign_openssl(const std::string &message, const std::string &key, std::string &signature) {
char signature_buf[EVP_MAX_MD_SIZE];
unsigned int signature_len = 0;
HMAC_CTX *hmac = HMAC_CTX_new();
if (hmac == nullptr) {
return false;
}
std::string real_key;
base64_decode(key, real_key);
bool ret = true;
if (HMAC_Init_ex(hmac, real_key.c_str(), real_key.length(), EVP_sha256(), NULL) <= 0
||HMAC_Update(hmac, ( unsigned char* )message.c_str(), message.length()) <= 0
||HMAC_Final(hmac, (unsigned char*)signature_buf, &signature_len) <= 0 ) {
ret = false;
}
if (ret == true) {
std::string final_sig(signature_buf, signature_len);
signature = final_sig;
}
HMAC_CTX_free(hmac);
return ret;
}
/*
 * Check that `signature` is the HMAC-SHA256 of `message` (OpenSSL backend).
 * The MAC is recomputed with hmac_sha256_sign_openssl() and compared against
 * `signature`, which is only read here.
 * Returns true on an exact match, false on a mismatch or if the MAC cannot
 * be computed.
 */
bool SignAlgorithm::hmac_sha256_verify_openssl(const std::string &message, const std::string &key, std::string &signature) {
    std::string compute_signature;
    if (!hmac_sha256_sign_openssl(message, key, compute_signature)) {
        return false;
    }
    if (compute_signature.size() != signature.size()) {
        return false;
    }
    /* Compare every byte regardless of where the first difference is, so the
     * time taken does not reveal how many leading bytes of a forged MAC match. */
    unsigned char diff = 0;
    for (std::string::size_type i = 0; i < compute_signature.size(); ++i) {
        diff |= static_cast<unsigned char>(compute_signature[i] ^ signature[i]);
    }
    return diff == 0;
}
6 log库
6.1 boost log使用版本
使用boost的log实现的方法,简单来说基本不需要做什么操作,就是初始化完成就可以调用记录日志的函数了。需要注意的是这个是同步的sink
/* .h文件 */
#ifndef __BOOST_BASE_LOG__
#define __BOOST_BASE_LOG__
#include <iostream>
#include <boost/log/core.hpp>
#include <boost/log/trivial.hpp>
#include <boost/log/expressions.hpp>
#include <boost/log/sinks/text_file_backend.hpp>
#include <boost/log/utility/setup/file.hpp>
#include <boost/log/utility/setup/common_attributes.hpp>
#include <boost/log/sources/severity_logger.hpp>
#include <boost/log/sources/record_ostream.hpp>
void init_boost_base_log();
#endif
/* .cpp文件 */
#include "boost_base_log.h"
void init_boost_base_log() {
boost::log::add_file_log(
/* 具体记录日志的名称 */
boost::log::keywords::file_name = "sample_%N.log",
/* 文件10MB的时候rotate一次 */
boost::log::keywords::rotation_size = 10 * 1024 * 1024,
/* 或者每天午夜的时候rotate一次 */
boost::log::keywords::time_based_rotation = boost::log::sinks::file::rotation_at_time_point(0, 0, 0),
/* 消息格式 */
boost::log::keywords::format = "[%TimeStamp%]: %Message%"
);
/* 添加附加属性比方说时间戳,行号,线程号等 */
boost::log::add_common_attributes();
}
/* Usage: one statement per severity level, from least to most severe. */
BOOST_LOG_TRIVIAL(trace) << "log message";
BOOST_LOG_TRIVIAL(debug) << "log message";
BOOST_LOG_TRIVIAL(info) << "log message";
BOOST_LOG_TRIVIAL(warning) << "log message";
BOOST_LOG_TRIVIAL(error) << "log message";
BOOST_LOG_TRIVIAL(fatal) << "log message";
/* Build and link notes
 * With CMake, add the line add_definitions(-DBOOST_LOG_DYN_LINK),
 * then link with
 * target_link_libraries(boost_log_test -lboost_log-mt -lboost_log_setup-mt -lboost_thread-mt)
 */
6.2 spdlog
#include <memory>
#include <thread>
#include <iostream>
#include "spdlog/spdlog.h"
#include "spdlog/sinks/rotating_file_sink.h"
#include "spdlog/sinks/daily_file_sink.h"
#include "spdlog/sinks/stdout_color_sinks.h"
/*
 * Logging shorthands.
 * Disabled variant (#if 0): log to the default logger and to "file_log".
 * It expands to two statements, so it is wrapped in do { } while (0) to stay
 * a single statement when used as the body of an unbraced if/else.
 * Active variant: log only to the logger registered as "file_log".
 */
#if 0
#define LOG_DEBUG(...) do { SPDLOG_LOGGER_DEBUG(spdlog::default_logger_raw(), __VA_ARGS__); SPDLOG_LOGGER_DEBUG(spdlog::get("file_log"), __VA_ARGS__); } while (0)
#define LOG_INFO(...) do { SPDLOG_LOGGER_INFO(spdlog::default_logger_raw(), __VA_ARGS__); SPDLOG_LOGGER_INFO(spdlog::get("file_log"), __VA_ARGS__); } while (0)
#define LOG_WARN(...) do { SPDLOG_LOGGER_WARN(spdlog::default_logger_raw(), __VA_ARGS__); SPDLOG_LOGGER_WARN(spdlog::get("file_log"), __VA_ARGS__); } while (0)
#define LOG_ERROR(...) do { SPDLOG_LOGGER_ERROR(spdlog::default_logger_raw(), __VA_ARGS__); SPDLOG_LOGGER_ERROR(spdlog::get("file_log"), __VA_ARGS__); } while (0)
#else
#define LOG_DEBUG(...) SPDLOG_LOGGER_DEBUG(spdlog::get("file_log"), __VA_ARGS__)
#define LOG_INFO(...) SPDLOG_LOGGER_INFO(spdlog::get("file_log"), __VA_ARGS__)
#define LOG_WARN(...) SPDLOG_LOGGER_WARN(spdlog::get("file_log"), __VA_ARGS__)
#define LOG_ERROR(...) SPDLOG_LOGGER_ERROR(spdlog::get("file_log"), __VA_ARGS__)
#endif
/* Handle to the rotating file logger; assigned by init_log(). */
std::shared_ptr<spdlog::logger> file_logger;
/*
 * Create the "file_log" rotating logger (file "log", 100 MiB per file,
 * 3 rotated files), schedule a flush every three seconds and install the
 * global record pattern.
 */
void init_log() {
    const std::size_t max_file_bytes = 1024 * 1024 * 100;
    const std::size_t max_rotated_files = 3;
    file_logger = spdlog::rotating_logger_mt("file_log", "log", max_file_bytes, max_rotated_files);
    /* Alternative kept for reference: a daily logger. */
    //auto logger = spdlog::daily_logger_mt("daily_logger", "logs/daily.txt", 2, 30);
    /* Alternative kept for reference: flush whenever a warn record is logged. */
    //file_logger->flush_on(spdlog::level::warn);
    const std::chrono::seconds flush_period(3);
    spdlog::flush_every(flush_period);
    spdlog::set_pattern("%Y-%m-%d %H:%M:%S [%l] [%t] - <%s>|<%#>|<%!>,%v");
}
void stop_log() {
spdlog::drop_all();
}
7 varint编解码
在实现内部工程的时候遇到了varint编解码的实现问题。这里借鉴了网上的开源代码,修复了原代码只适用于unsigned int类型数据的bug,并且添加了一个解码时获得varint长度的函数。需要注意的是,varint解码时,待解码的数据是否合法需要调用方自己判断。检查的方法也非常简单:解码varint时检查是不是超过5字节,解码varlong时检查是不是超过10字节;另外还要检查是否已经到达数据边界,而最后一个字节的最高位仍然为1(表示后面还有下一个字节)。代码大致如下:
/* Validity check for a 32-bit varint held in input[0..length): counts the
 * bytes of the encoding and rejects it when it is too long or cut short. */
int varint_size = 1;
for (int i = 0 ; i < length; ++i) {
/* High bit set: another byte follows this one. */
if (input[i] & 0x80) {
varint_size += 1;
/* More than 5 bytes cannot be a 32-bit varint. */
if (varint_size > 5) {
return false;
}
} else {
break;
}
}
/* Reaching this point means the 5-byte limit was not exceeded, but the data
 * may still have run out: if every available byte had its high bit set, the
 * encoding is truncated and therefore invalid. */
if (varint_size == length + 1) {
return false;
}
头文件
#ifndef __VARINT_HPP__
#define __VARINT_HPP__
#include <iostream>
std::size_t varintSize(std::size_t value);
/**
 * Encodes an integer as a variable-length integer (7 payload bits per byte,
 * least-significant group first, high bit set on every byte except the last).
 *
 * Signed values are encoded through the unsigned bit pattern of their own
 * width, so a negative int32_t takes 5 bytes and a negative int64_t takes
 * 10 bytes; non-negative and unsigned values encode exactly as before.
 *
 * @param value  The input value. Any standard integer type up to 64 bits.
 * @param output A pointer to a piece of reserved memory. Must have a minimum
 *               size dependent on the input size (32 bit = 5 bytes,
 *               64 bit = 10 bytes).
 * @return The number of bytes used in the output memory.
 */
template<typename int_t = uint64_t>
std::size_t encodeVarint(int_t value, uint8_t* output) {
    static_assert(sizeof(int_t) <= sizeof(uint64_t), "encodeVarint supports integers up to 64 bits");
    // Keep only the low sizeof(int_t)*8 bits: the conversion to uint64_t
    // sign-extends negative values, which would otherwise always need 10 bytes.
    const unsigned valueBits = static_cast<unsigned>(sizeof(int_t) * 8);
    const uint64_t widthMask = ~static_cast<uint64_t>(0) >> (64 - valueBits);
    uint64_t remaining = static_cast<uint64_t>(value) & widthMask;
    std::size_t outputSize = 0;
    // While more than 7 bits of data are left, emit the low 7 bits with the
    // next-byte flag (128) set and drop them from the value.
    while (remaining > 127) {
        output[outputSize++] = static_cast<uint8_t>((remaining & 127) | 128);
        remaining >>= 7;
    }
    // Final byte: at most 7 bits remain, next-byte flag clear.
    output[outputSize++] = static_cast<uint8_t>(remaining);
    return outputSize;
}
/**
 * Decodes a variable-length integer (7 payload bits per byte,
 * least-significant group first, high bit set when another byte follows).
 *
 * Decoding stops at the first byte without the next-byte flag, at the end of
 * the buffer, or after 10 bytes (the most a 64-bit value can occupy),
 * whichever comes first. The input is not validated beyond that; callers
 * must check for truncated or over-long encodings themselves.
 *
 * @param input     Pointer to the encoded bytes.
 * @param inputSize How many bytes are available at `input`.
 * @return The decoded value converted to int_t (0 when inputSize is 0).
 */
template<typename int_t = uint64_t>
int_t decodeVarint(uint8_t* input, std::size_t inputSize) {
    // Accumulate in a fixed 64-bit type so the result does not depend on the
    // platform's size_t width.
    uint64_t accumulated = 0;
    for (std::size_t i = 0; i < inputSize; i++) {
        const std::size_t shift = 7 * i;
        // Shifting a 64-bit value by 64 or more is undefined; ignore any
        // bytes that could not contribute to the result anyway.
        if (shift >= 64) {
            break;
        }
        accumulated |= static_cast<uint64_t>(input[i] & 127) << shift;
        // Stop when the next-byte flag is clear.
        if (!(input[i] & 128)) {
            break;
        }
    }
    return static_cast<int_t>(accumulated);
}
#endif
cpp文件
#include "varint.hpp"
/* Number of bytes the varint encoding of `value` occupies: one byte, plus
 * one more for every additional 7-bit group above the first. */
std::size_t varintSize(std::size_t value) {
    std::size_t byteCount = 1;
    for (std::size_t rest = value; rest > 127; rest /= 128) {
        ++byteCount;
    }
    return byteCount;
}
线程池
先来点基础知识:
C++ FUTURE提供了下面这些类型:
- Providers 类:std::promise, std::packaged_task
- Futures 类:std::future, shared_future.
- Providers 函数:std::async()
- 其他类型:std::future_error, std::future_errc, std::future_status, std::launch.
std::future && std::packaged_task
什么是std::future? std::future 可以用来获取异步任务的结果,因此可以把它当成一种简单的线程间同步的手段。std::future 通常由某个 Provider 创建,你可以把 Provider 想象成一个异步任务的提供者,Provider 在某个线程中设置共享状态的值,与该共享状态相关联的 std::future 对象调用 get(通常在另外一个线程中) 获取该值,如果共享状态的标志不为 ready,则调用 std::future::get 会阻塞当前的调用者,直到 Provider 设置了共享状态的值(此时共享状态的标志变为 ready),std::future::get 返回异步任务的值或异常(如果发生了异常)。
什么是std::packaged_task? std::packaged_task 包装一个可调用的对象,并且允许异步获取该可调用对象产生的结果,std::packaged_task 与 std::function 类似,只不过 std::packaged_task 将其包装的可调用对象的执行结果通过task.get_future传递给一个 std::future 对象(该对象通常在另外一个线程中获取 std::packaged_task 任务的执行结果)。
有了这些基础知识,就可以结合线程池来实现异步执行:
#include <memory>
#include <utility>
#include "thread_pool.h"
namespace async {
// Schedule a task to the given thread pool. If thread_pool is null, run the
// task synchronously on the current thread.
template <typename Func, typename... Args>
auto ScheduleFuture(ThreadPool* thread_pool, Func&& f, Args&&... args) {
if (thread_pool == nullptr) {
Future<typename std::result_of<Func(Args...)>::type> future(
std::async(std::launch::deferred, std::forward<Func>(f),
std::forward<Args>(args)...));
future.Wait();
return future;
}
return thread_pool->Schedule(std::forward<Func>(f),
std::forward<Args>(args)...);
}
// Schedule a task to the default thread pool.
template <typename Func, typename... Args>
auto ScheduleFuture(Func&& f, Args&&... args) {
return ScheduleFuture(ThreadPool::DefaultPool(), std::forward<Func>(f),
std::forward<Args>(args)...);
}
// Asynchronously destroy a container.
template <typename ContainerT>
void DestroyContainerAsync(ThreadPool* thread_pool, ContainerT container) {
ScheduleFuture(thread_pool, [_ = std::move(container)]() mutable {
const auto unused = std::move(_);
});
}
// Same, but use the default thread pool.
template <typename ContainerT>
void DestroyContainerAsync(ContainerT container) {
DestroyContainerAsync(ThreadPool::DisposalPool(), std::move(container));
}
template <typename T>
void WaitForFuture(const Future<T>& future) {
future.Wait();
}
} // namespace async
下面看线程池的代码,
// A pool of worker threads that run queued std::function<void()> tasks.
// The class is neither copyable nor copy-assignable.
class ThreadPool {
 public:
  // Creates a pool for `num_workers` workers. `init_thread` is an optional
  // callback taking a worker index; it defaults to an empty function.
  explicit ThreadPool(int num_workers,
                      const std::function<void(int index)>& init_thread = {});
  ~ThreadPool();
  // Return the default thread pool.
  static ThreadPool* DefaultPool();
  // Return the disposal thread pool to destroy stuff asynchronously.
  static ThreadPool* DisposalPool();
  ThreadPool(const ThreadPool&) = delete;
  ThreadPool& operator=(const ThreadPool&) = delete;
  // Number of worker threads held by this pool.
  int NumWorkers() const { return workers_.size(); }
  // Future type returned for a callable Func invoked with Args.
  template <class Func, class... Args>
  using FutureType = Future<typename std::result_of<Func(Args...)>::type>;
  // Schedule a new task. Annotated LOCKS_EXCLUDED(mutex_): the caller must
  // not already hold mutex_.
  template <class Func, class... Args>
  FutureType<Func, Args...> Schedule(Func&& f, Args&&... args)
      LOCKS_EXCLUDED(mutex_);
 private:
  // Protects the members annotated GUARDED_BY(mutex_) below.
  absl::Mutex mutex_;
  // Condition variable used together with mutex_.
  absl::CondVar cond_var_ GUARDED_BY(mutex_);
  // The worker threads.
  std::vector<std::thread> workers_;
  // Pending tasks; guarded by mutex_ and used together with cond_var_.
  std::queue<std::function<void()>> tasks_ GUARDED_BY(mutex_);
  // Stop flag for the pool; initially false.
  bool stop_requested_ GUARDED_BY(mutex_) = false;
};
// Binds f to args, wraps the call in a packaged_task and returns a Future
// for its result. With no workers the task runs immediately on the calling
// thread; otherwise it is queued, one waiter is signalled and the queue
// length is reported through QCOUNTER.
template <class Func, class... Args>
ThreadPool::FutureType<Func, Args...> ThreadPool::Schedule(Func&& f,
                                                           Args&&... args) {
  using ResultT = typename std::result_of<Func(Args...)>::type;
  using PackagedT = std::packaged_task<ResultT()>;
  const auto packaged = std::make_shared<PackagedT>(
      std::bind(std::forward<Func>(f), std::forward<Args>(args)...));
  Future<ResultT> result(packaged->get_future());
  // No workers: this is an inline pool, run the task right here.
  if (workers_.empty()) {
    (*packaged)();
    return result;
  }
  int64_t queue_depth = 0;
  {
    absl::MutexLock lock(&mutex_);
    // QCHECK(!stop_requested_) << "The thread pool has been stopped";
    tasks_.emplace([packaged]() { (*packaged)(); });
    queue_depth = tasks_.size();
    cond_var_.Signal();
  }
  // Reported after the lock has been released.
  QCOUNTER("schedulefuture_callback_size", queue_depth);
  return result;
}
}
#include "thread_pool.h"
// Starts `num_workers` threads. Each one first calls `init_thread` with its
// index (when a callback was supplied), then keeps taking tasks from the
// queue until a stop has been requested and the queue is empty.
ThreadPool::ThreadPool(int num_workers,
                       const std::function<void(int index)>& init_thread) {
  for (int worker_index = 0; worker_index < num_workers; ++worker_index) {
    workers_.emplace_back([this, worker_index, init_thread] {
      if (init_thread) {
        init_thread(worker_index);
      }
      for (;;) {
        std::function<void()> job;
        {
          absl::MutexLock lock(&mutex_);
          // Sleep while there is nothing to do and no stop was requested.
          while (tasks_.empty() && !stop_requested_) {
            cond_var_.Wait(&mutex_);
          }
          // The wait loop only ends with an empty queue when a stop was
          // requested, so an empty queue here means the worker is done.
          if (tasks_.empty()) {
            return;
          }
          job = std::move(tasks_.front());
          tasks_.pop();
        }
        // Run the task after the lock has been released.
        job();
      }
    });
  }
}
// Requests a stop, wakes every waiting worker and joins all worker threads.
ThreadPool::~ThreadPool() {
  {
    absl::MutexLock guard(&mutex_);
    stop_requested_ = true;
    cond_var_.SignalAll();
  }
  for (auto it = workers_.begin(); it != workers_.end(); ++it) {
    it->join();
  }
}
// static
// Returns the shared default pool, created on first use with
// FLAGS_default_pool_size workers; the pool object is never deleted.
ThreadPool* ThreadPool::DefaultPool() {
  static ThreadPool& pool = *new ThreadPool(FLAGS_default_pool_size);
  return &pool;
}
// static
// Returns the shared single-worker disposal pool, created on first use;
// the pool object is never deleted.
ThreadPool* ThreadPool::DisposalPool() {
  static ThreadPool& pool = *new ThreadPool(1);
  return &pool;
}
结尾
唉,尴尬