#include "server.h"

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <memory>
#include <mutex>
#include <thread>
#include <type_traits>
#include <utility>

// Binds the server to an io_context and remembers the port to listen on.
// The acceptor is created unopened; start() opens, binds and listens on it.
Server::Server(asio::io_context& io_context, uint16_t port)
    : io_context_(io_context),
      acceptor_(io_context),
      port_(port)
{
}

// Releases ownership of every remaining client thread.
// Client threads erase their own entry from clients_ under cli_mutex_ when they
// finish, so the map is walked under the same lock to avoid racing with them.
Server::~Server()
{
    std::lock_guard<std::mutex> lock(cli_mutex_);
    for (auto& client : clients_) {
        // std::thread::detach() throws on a non-joinable thread, and an exception
        // escaping a destructor terminates the program.
        if (client.second.joinable()) {
            client.second.detach();
        }
    }
    // NOTE(review): detached client threads keep using *this (members, cli_mutex_)
    // after this destructor returns; a proper shutdown would close the client
    // sockets and join the threads instead.
}

// Writes the exception's message to stderr, terminated by a newline.
void Server::print_exception(const std::exception& e)
{
    const char* const message = e.what();
    std::cerr << message << '\n';
}

void Server::start()
{
    asio::ip::tcp::endpoint endpoint(asio::ip::tcp::v4(), port_);
    try {
        acceptor_.open(endpoint.protocol());
        // acceptor_.set_option(asio::socket_base::reuse_address(true));
        acceptor_.bind(endpoint);
        acceptor_.listen();
        do_accept();
    } catch (const std::exception& e) {
        print_exception(e);
    }
}

// Currently a no-op: the acceptor and client connections are left untouched.
// NOTE(review): closing acceptor_ here also requires do_accept() to stop
// re-arming when the accept completes with asio::error::operation_aborted.
void Server::stop()
{
}

// Installs the upstream request worker and the JSON helper used by th_client()
// and post_data().
// Both arguments are sink parameters taken by value, so they are moved into the
// members instead of copied (avoids an extra atomic ref-count inc/dec each).
void Server::set_worker(std::shared_ptr<COpenAI> worker, std::shared_ptr<CJsonOper> json)
{
    worker_ = std::move(worker);
    json_ = std::move(json);
}

void Server::set_token(int32_t tokens)
{
    tokens_ = tokens;
}

void Server::do_accept()
{
    auto socket = std::make_shared<asio::ip::tcp::socket>(io_context_);
    acceptor_.async_accept(*socket, [this, socket](const std::error_code& ec) {
        if (!ec) {
            auto endpoint = socket->remote_endpoint();
            std::string client_key = endpoint.address().to_string() + ":" + std::to_string(endpoint.port());
            std::unique_lock<std::mutex> lock(cli_mutex_);
            client_map_[client_key] = std::make_shared<ClientCache>();
            clients_.insert(std::make_pair(socket->remote_endpoint().address().to_string(),
                                           std::thread([this, socket, client_key]() { th_client(socket, client_key); })));
        }

        do_accept();
    });
}

// Per-connection worker thread.
// It reads from `socket` into the client's ClientCache and parses complete frames
// with com_parse(). For each frame it does one of the following:
//   - replies TYPE_OUT_OF_LIMIT if the token budget is exhausted;
//   - otherwise, for TYPE_REQUEST frames, forwards the payload to worker_ and
//     replies with TYPE_RESPONSE_SUCCESS (with token counts) or
//     TYPE_RESPONSE_ERROR.
// The loop ends when the peer closes the connection or a read fails. Exceptions
// are caught and printed here, because an exception escaping a std::thread entry
// point would terminate the whole process.
void Server::th_client(const std::shared_ptr<asio::ip::tcp::socket>& socket, const std::string& client_key)
{
    // Scope guard: the custom deleter runs when this function exits (normally or
    // via an exception) and removes this connection from client_map_ and clients_.
    // The std::thread stored in clients_ is the one executing this function, so it
    // is detached rather than joined.
    std::shared_ptr<int> deleter(new int(0), [&](int* p) {
        std::unique_lock<std::mutex> lock(cli_mutex_);
        delete p;
        client_map_.erase(client_key);
        auto it = clients_.find(client_key);
        if (it != clients_.end()) {
            // detach() throws on a non-joinable thread, and this runs inside a destructor.
            if (it->second.joinable()) {
                it->second.detach();
            }
            clients_.erase(it);
        }
        std::cout << "th_client deleter client " << client_key << "exit." << std::endl;
    });

    std::shared_ptr<ClientCache> cache;
    {
        std::unique_lock<std::mutex> lock(cli_mutex_);
        // find() rather than operator[], so a missing entry is not silently
        // inserted as nullptr and then dereferenced.
        auto it = client_map_.find(client_key);
        if (it != client_map_.end()) {
            cache = it->second;
        }
    }
    if (!cache) {
        return;   // Nothing registered for this connection; `deleter` cleans up.
    }

    try {
        asio::error_code error;
        while (true) {
            auto len = socket->read_some(asio::buffer(cache->tmp_buf_), error);
            if (error) {
                break;   // asio::error::eof: closed cleanly by peer; anything else: read error.
            }

            cache->buffer_.push(cache->tmp_buf_.data(), len);

            // Handle every complete frame currently buffered.
            while (true) {
                auto* raw_frame = com_parse(cache->buffer_);
                if (raw_frame == nullptr) {
                    break;   // No complete frame yet; read more data.
                }
                // Own the parsed frame so it is deleted on every path, including the
                // out-of-limit `continue` below and exceptions.
                std::unique_ptr<std::remove_pointer<decltype(raw_frame)>::type> frame(raw_frame);

                // NOTE(review): use_tokens_ is updated under ask_mutex_ below but read
                // here without it; confirm it is atomic, otherwise this is a data race.
                if (use_tokens_ > tokens_) {
                    std::cout << client_key << " tokens not enough" << std::endl;
                    FrameData req;
                    req.type = FrameType::TYPE_OUT_OF_LIMIT;
                    send_frame(socket, req);
                    continue;   // `frame` is still released by its unique_ptr.
                }
                std::cout << client_key << " 's data." << std::endl;
                if (frame->type == FrameType::TYPE_REQUEST) {
                    // Serializes upstream requests and token accounting across all
                    // client threads. The guard also releases the mutex if
                    // post/parse/send throws.
                    std::lock_guard<decltype(ask_mutex_)> ask_lock(ask_mutex_);
                    std::string recv_data(frame->data, frame->len);
                    std::string out{};
                    if (!worker_->post(post_data(recv_data), out)) {
                        std::cout << client_key << " data post error" << std::endl;
                        FrameData req;
                        req.type = FrameType::TYPE_RESPONSE_ERROR;
                        send_frame(socket, req);
                    } else {
                        auto parse = json_->parse(out);
                        FrameData req;
                        req.type = FrameType::TYPE_RESPONSE_SUCCESS;
                        req.len = parse.message_content.size();
                        // NOTE(review): send_frame() does not free req.data; confirm
                        // FrameData's destructor releases this allocation.
                        req.data = new char[req.len];
                        req.protk = parse.prompt_tokens;
                        req.coptk = parse.completion_tokens;
                        use_tokens_ += req.protk;
                        use_tokens_ += req.coptk;
                        std::cout << "Already use " << use_tokens_ << " tokens.\n";
                        memcpy(req.data, parse.message_content.c_str(), parse.message_content.size());
                        send_frame(socket, req);
                    }
                }
            }
        }
    } catch (const std::exception& e) {
        print_exception(e);
    }
}

std::string Server::post_data(const std::string& data)
{
    return json_->format_request(data);
}

// Serializes `data` with com_pack() and writes the whole packet to `socket`.
// Returns false if packing fails or the socket reports an error. It never throws
// on socket errors: the throwing send() overload would propagate out of the
// client thread and terminate the process.
bool Server::send_frame(const std::shared_ptr<asio::ip::tcp::socket>& socket, FrameData& data)
{
    char* send_data{};
    int len{};

    if (!com_pack(&data, &send_data, len)) {
        return false;
    }

    // The buffer produced by com_pack is released with delete[]; owning it here
    // frees it on every return path.
    std::unique_ptr<char[]> packed(send_data);
    if (len < 0) {
        return false;
    }

    const auto total = static_cast<std::size_t>(len);
    std::size_t sent = 0;
    asio::error_code error;
    // socket::send() may transmit only part of the buffer, so keep sending until
    // everything is written or an error occurs.
    while (sent < total) {
        const std::size_t n = socket->send(asio::buffer(packed.get() + sent, total - sent), 0, error);
        if (error || n == 0) {
            return false;
        }
        sent += n;
    }
    return true;
}