#include "server.h"

#include <cstring>
#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <type_traits>
#include <utility>

// Creates a server bound to `io_context`. The acceptor is only constructed here;
// start() opens, binds and listens on `port`.
Server::Server(asio::io_context& io_context, short port)
    : io_context_(io_context), acceptor_(io_context) {
    port_ = port;
}

// Detaches every client thread that is still registered in clients_.
//
// NOTE(review): detached client threads keep using `this` (cli_mutex_,
// client_map_, worker_, ...) until their peer disconnects, which is undefined
// behaviour once this object is destroyed. A real fix needs the client sockets
// to be closed and the threads joined, which requires state not visible here.
Server::~Server() {
    // Hold the lock so a finishing client thread cannot erase from clients_
    // while we iterate over it.
    std::unique_lock lock(cli_mutex_);
    for (auto& entry : clients_) {
        // detach() throws std::system_error on a non-joinable thread.
        if (entry.second.joinable()) {
            entry.second.detach();
        }
    }
}

// Writes the exception message to stderr.
void Server::print_exception(const std::exception& e) {
    std::cerr << e.what() << '\n';
}

// Opens an IPv4 acceptor on port_, starts listening and arms the first
// asynchronous accept. Setup errors are printed, not propagated.
void Server::start() {
    asio::ip::tcp::endpoint endpoint(asio::ip::tcp::v4(), port_);
    try {
        acceptor_.open(endpoint.protocol());
        // acceptor_.set_option(asio::socket_base::reuse_address(true));
        acceptor_.bind(endpoint);
        acceptor_.listen();
        do_accept();
    } catch (const std::exception& e) {
        print_exception(e);
    }
}

// Currently a no-op.
void Server::stop() {
}

// Installs the backend used to serve requests and the JSON helper used to
// format requests for it and parse its replies. The parameter types are written
// as the types of the members they are stored in (assumes server.h declares
// them identically).
void Server::set_worker(decltype(worker_) worker, decltype(json_) json) {
    worker_ = std::move(worker);
    json_ = std::move(json);
}

// Sets the token budget compared against use_tokens_ before serving a frame.
void Server::set_token(long tokens) {
    tokens_ = tokens;
}

// Arms one asynchronous accept. Each accepted connection gets a per-client
// cache in client_map_ and a dedicated thread (stored in clients_) running
// th_client(). The handler re-arms itself until the acceptor is closed.
void Server::do_accept() {
    auto socket = std::make_shared<asio::ip::tcp::socket>(io_context_);
    acceptor_.async_accept(*socket, [this, socket](const std::error_code& ec) {
        // After the acceptor is closed, re-arming would fail immediately and
        // loop forever, so stop here.
        if (ec == asio::error::operation_aborted || !acceptor_.is_open()) {
            return;
        }
        if (!ec) {
            // remote_endpoint() throws if the peer already went away; use the
            // error_code overload so one bad connection cannot escape into
            // io_context::run().
            asio::error_code endpoint_error;
            const auto endpoint = socket->remote_endpoint(endpoint_error);
            if (!endpoint_error) {
                const std::string client_key =
                    endpoint.address().to_string() + ":" + std::to_string(endpoint.port());

                // Hold cli_mutex_ until both entries exist, so the new thread
                // cannot look itself up (or clean itself up) too early.
                std::unique_lock lock(cli_mutex_);
                if (clients_.find(client_key) != clients_.end()) {
                    // A previous session with the same ip:port is still being
                    // torn down. Creating a thread that cannot be stored would
                    // destroy a joinable std::thread (std::terminate), so drop
                    // this connection instead; the socket closes when released.
                    std::cout << client_key << " is still registered, rejecting connection."
                              << std::endl;
                } else {
                    using Cache = decltype(client_map_)::mapped_type::element_type;
                    client_map_[client_key] = std::make_shared<Cache>();
                    // clients_ uses the same "ip:port" key as client_map_, which
                    // is the key th_client() uses to find and remove its entry.
                    clients_.emplace(client_key, std::thread([this, socket, client_key]() {
                        th_client(socket, client_key);
                    }));
                }
            }
        }
        do_accept();
    });
}

// Per-client thread body. Reads from `socket` until EOF or error, feeds the
// bytes into the client's buffer and handles every complete frame:
//   - if use_tokens_ exceeds tokens_, replies TYPE_OUT_OF_LIMIT;
//   - TYPE_REQUEST frames are forwarded to worker_ (serialised across all
//     clients by ask_mutex_) and answered with TYPE_RESPONSE_SUCCESS or
//     TYPE_RESPONSE_ERROR.
// On exit the client is removed from client_map_ and clients_. Exceptions are
// printed instead of escaping, because an exception leaving a std::thread entry
// point calls std::terminate.
void Server::th_client(const std::shared_ptr<asio::ip::tcp::socket>& socket,
                       const std::string& client_key) {
    // Unregisters this client however the function exits.
    struct ClientCleanup {
        Server& server;
        const std::string& key;
        ~ClientCleanup() {
            std::unique_lock lock(server.cli_mutex_);
            server.client_map_.erase(key);
            auto it = server.clients_.find(key);
            if (it != server.clients_.end()) {
                // The stored std::thread is the one running this code. Detach it
                // before erasing so destroying the object does not terminate.
                if (it->second.joinable()) {
                    it->second.detach();
                }
                server.clients_.erase(it);
            }
            std::cout << "th_client deleter client " << key << " exit." << std::endl;
        }
    } cleanup{*this, client_key};

    try {
        decltype(client_map_)::mapped_type cache;
        {
            std::unique_lock lock(cli_mutex_);
            auto it = client_map_.find(client_key);
            if (it != client_map_.end()) {
                cache = it->second;
            }
        }
        if (!cache) {
            return;
        }

        // com_parse() hands back a heap-allocated frame (the original code
        // released it with `delete`); own it so every path frees it.
        using FramePtr =
            std::unique_ptr<std::remove_pointer_t<decltype(com_parse(cache->buffer_))>>;

        asio::error_code error;
        while (true) {
            const auto len = socket->read_some(asio::buffer(cache->tmp_buf_), error);
            if (error) {
                // asio::error::eof: connection closed cleanly by peer. Any other
                // error is a transport failure. Either way the session is over.
                break;
            }
            cache->buffer_.push(cache->tmp_buf_.data(), len);

            // Drain every complete frame currently buffered.
            while (true) {
                FramePtr frame(com_parse(cache->buffer_));
                if (!frame) {
                    break;
                }

                // use_tokens_ is updated under ask_mutex_ by other client
                // threads, so it is read under the same lock.
                bool out_of_tokens = false;
                {
                    std::lock_guard tokens_lock(ask_mutex_);
                    out_of_tokens = use_tokens_ > tokens_;
                }
                if (out_of_tokens) {
                    std::cout << client_key << " tokens not enough" << std::endl;
                    FrameData req;
                    req.type = FrameType::TYPE_OUT_OF_LIMIT;
                    send_frame(socket, req);
                    continue;  // `frame` is released by FramePtr.
                }

                std::cout << client_key << " 's data." << std::endl;
                if (frame->type != FrameType::TYPE_REQUEST) {
                    continue;
                }

                // Released on every exit from this iteration, including exceptions.
                std::lock_guard ask_lock(ask_mutex_);
                const std::string recv_data(frame->data, frame->len);
                std::string out{};
                if (!worker_->post(post_data(recv_data), out)) {
                    std::cout << client_key << " data post error" << std::endl;
                    FrameData req;
                    req.type = FrameType::TYPE_RESPONSE_ERROR;
                    send_frame(socket, req);
                    continue;
                }

                auto parse = json_->parse(out);
                FrameData req;
                req.type = FrameType::TYPE_RESPONSE_SUCCESS;
                req.len = parse.message_content.size();
                // NOTE(review): ownership of req.data is not visible here; if
                // FrameData does not delete[] it in its destructor, this leaks.
                req.data = new char[req.len];
                req.protk = parse.prompt_tokens;
                req.coptk = parse.completion_tokens;
                use_tokens_ += req.protk;
                use_tokens_ += req.coptk;
                std::cout << "Already use " << use_tokens_ << " tokens.\n";
                std::memcpy(req.data, parse.message_content.c_str(),
                            parse.message_content.size());
                send_frame(socket, req);
            }
        }
    } catch (const std::exception& e) {
        print_exception(e);
    }
}

// Wraps the raw client payload into the backend request format.
std::string Server::post_data(const std::string& data) {
    return json_->format_request(data);
}

// Serialises `data` with com_pack() and writes all of it to `socket`.
// Returns false if packing fails or the write errors or is incomplete; it does
// not throw on socket errors.
bool Server::send_frame(const std::shared_ptr<asio::ip::tcp::socket>& socket, FrameData& data) {
    char* send_data{};
    int len{};
    if (!com_pack(&data, &send_data, len)) {
        return false;
    }
    // Release the packed buffer on every path below.
    const std::unique_ptr<char[]> packed(send_data);
    if (len < 0) {
        return false;
    }
    const auto total = static_cast<std::size_t>(len);

    // socket->send() may transmit only part of the buffer and throws on error;
    // asio::write() keeps writing until everything is sent and reports failure
    // through `error`.
    asio::error_code error;
    const std::size_t sent = asio::write(*socket, asio::buffer(packed.get(), total), error);
    return !error && sent == total;
}