#include "muduo/base/Logging.h"
#include "muduo/net/EventLoop.h"
#include "muduo/net/InetAddress.h"
#include "muduo/net/TcpClient.h"
#include "muduo/net/TcpServer.h"

#include <functional>
#include <map>
#include <memory>
#include <string>

#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>

using namespace muduo;
using namespace muduo::net;

typedef std::shared_ptr<TcpClient> TcpClientPtr;

// Packet format exchanged with the upstream peer (in both directions):
//   byte 0     payload length (0..255)
//   bytes 1-2  logical connection id, low byte first; id 0 carries text commands
//   bytes 3..  payload
// const int kMaxConns = 1;
const size_t kMaxPacketLen = 255;
const size_t kHeaderLen = 3;
const int kMaxConnId = 0xFFFF;  // largest id the 16-bit header field can carry

const uint16_t kListenPort = 9999;
const char* socksIp = "127.0.0.1";
const uint16_t kSocksPort = 7777;

/// State for one logical connection announced by the upstream peer.
struct Entry
{
  int connId = 0;
  TcpClientPtr client;          // outgoing client to the SOCKS server
  TcpConnectionPtr connection;  // set once `client` connects; kept after it closes
  Buffer pending;               // upstream data received before `connection` was up
  bool closedByMux = false;     // peer already sent "IS DOWN"; don't answer with DISCONNECT
};

/// Accepts a single upstream connection carrying multiplexed packets. For every
/// logical connection the peer announces ("CONN <id> ... IS UP"), opens a TcpClient
/// to the SOCKS server and relays payloads in both directions, tagged with the id.
/// A second upstream connection is shut down while one is active.
///
/// Lifecycle of socksConns_[id]:
///   "IS UP"            -> entry created, TcpClient starts connecting.
///   "IS DOWN"          -> if the SOCKS connection is still open, shut it down and
///                         erase the entry once it has closed; otherwise erase now.
///   SOCKS side closes  -> send "DISCONNECT <id>" upstream and keep the entry until
///                         "IS DOWN" arrives (erase at once if the peer already sent
///                         it, or if there is no upstream connection).
class DemuxServer : noncopyable
{
 public:
  DemuxServer(EventLoop* loop, const InetAddress& listenAddr, const InetAddress& socksAddr)
    : loop_(loop),
      server_(loop, listenAddr, "DemuxServer"),
      socksAddr_(socksAddr)
  {
    server_.setConnectionCallback(
        std::bind(&DemuxServer::onServerConnection, this, _1));
    server_.setMessageCallback(
        std::bind(&DemuxServer::onServerMessage, this, _1, _2, _3));
  }

  void start()
  {
    server_.start();
  }

  /// Accepts the first upstream connection; refuses others while it is active.
  /// When the active one goes down, every logical connection is dropped.
  void onServerConnection(const TcpConnectionPtr& conn)
  {
    if (conn->connected())
    {
      if (serverConn_)
      {
        conn->shutdown();
      }
      else
      {
        serverConn_ = conn;
        LOG_INFO << "onServerConnection set serverConn_";
      }
    }
    else
    {
      if (serverConn_ == conn)
      {
        serverConn_.reset();
        // Releases every TcpClient; late callbacks for them find no entry and
        // return early (see onSocksConnection / onSocksMessage).
        socksConns_.clear();
        LOG_INFO << "onServerConnection reset serverConn_";
      }
    }
  }

  /// Splits the upstream byte stream into packets: id 0 is a text command,
  /// any other id is payload for that logical connection.
  void onServerMessage(const TcpConnectionPtr& conn, Buffer* buf, Timestamp)
  {
    // '>=': a header announcing a zero-length payload is already a full packet.
    while (buf->readableBytes() >= kHeaderLen)
    {
      const char* header = buf->peek();
      const size_t len = static_cast<uint8_t>(header[0]);
      if (buf->readableBytes() < len + kHeaderLen)
      {
        break;  // wait for the rest of the payload
      }
      const int connId = static_cast<uint8_t>(header[1])
                       | (static_cast<uint8_t>(header[2]) << 8);
      const char* payload = header + kHeaderLen;
      if (connId != 0)
      {
        forwardToSocks(connId, payload, len);
      }
      else
      {
        doCommand(string(payload, len));
      }
      buf->retrieve(len + kHeaderLen);
    }
  }

  /// Parses a command-channel message. Only "CONN <id> ..." is understood; it is
  /// treated as "up" when it contains " IS UP" and as "down" otherwise.
  void doCommand(const string& cmd)
  {
    static const string kConn = "CONN ";
    // Validate the prefix before reading the id that follows it.
    if (cmd.size() <= kConn.size() || cmd.compare(0, kConn.size(), kConn) != 0)
    {
      LOG_ERROR << "doCommand: unknown command " << cmd;
      return;
    }
    int connId = atoi(cmd.c_str() + kConn.size());
    bool isUp = cmd.find(" IS UP") != string::npos;
    LOG_INFO << "doCommand " << connId << " " << isUp;
    // Id 0 is the command channel itself, and the header only carries 16 bits.
    if (connId <= 0 || connId > kMaxConnId)
    {
      LOG_ERROR << "doCommand: invalid connId " << connId;
      return;
    }
    if (isUp)
    {
      openSocksConn(connId);
    }
    else
    {
      closeSocksConn(connId);
    }
  }

  /// Connection callback of the TcpClient created for `connId`.
  void onSocksConnection(int connId, const TcpConnectionPtr& conn)
  {
    std::map<int, Entry>::iterator it = socksConns_.find(connId);
    if (it == socksConns_.end())
    {
      return;  // entry already dropped (e.g. upstream went away)
    }
    Entry& entry = it->second;
    if (conn->connected())
    {
      entry.connection = conn;
      if (entry.pending.readableBytes() > 0)
      {
        conn->send(&entry.pending);
      }
    }
    else
    {
      if (entry.connection != conn)
      {
        return;  // stale connection of an entry that was replaced for a reused id
      }
      if (entry.closedByMux || !serverConn_)
      {
        socksConns_.erase(it);
      }
      else
      {
        // Ask the peer to close its side; the entry is erased when its
        // "IS DOWN" arrives.
        char line[256];
        int n = snprintf(line, sizeof line, "DISCONNECT %d\r\n", connId);
        Buffer buffer;
        buffer.append(line, static_cast<size_t>(n));
        sendServerPacket(0, &buffer);
      }
    }
  }

  /// Message callback of the TcpClient for `connId`: relays SOCKS data upstream
  /// in packets of at most kMaxPacketLen bytes.
  void onSocksMessage(int connId, const TcpConnectionPtr& conn, Buffer* buf, Timestamp)
  {
    std::map<int, Entry>::iterator it = socksConns_.find(connId);
    if (it == socksConns_.end() || it->second.connection != conn || !serverConn_)
    {
      buf->retrieveAll();  // nowhere to deliver it
      return;
    }
    while (buf->readableBytes() > kMaxPacketLen)
    {
      Buffer packet;
      packet.append(buf->peek(), kMaxPacketLen);
      buf->retrieve(kMaxPacketLen);
      sendServerPacket(connId, &packet);
    }
    if (buf->readableBytes() > 0)
    {
      sendServerPacket(connId, buf);
    }
  }

  /// Prepends the packet header to `buf` and sends it upstream. Without an
  /// upstream connection the data is consumed and discarded, so a caller's input
  /// buffer is not left holding a stray header.
  void sendServerPacket(int connId, Buffer* buf)
  {
    size_t len = buf->readableBytes();
    LOG_DEBUG << len;
    assert(len <= kMaxPacketLen);
    if (!serverConn_)
    {
      buf->retrieveAll();
      return;
    }
    uint8_t header[kHeaderLen] = {
      static_cast<uint8_t>(len),
      static_cast<uint8_t>(connId & 0xFF),
      static_cast<uint8_t>((connId & 0xFF00) >> 8)
    };
    buf->prepend(header, kHeaderLen);
    serverConn_->send(buf);
  }

 private:
  /// Delivers an upstream payload to the SOCKS connection for `connId`, or
  /// buffers it until that connection is up. Unknown ids are dropped.
  void forwardToSocks(int connId, const char* data, size_t len)
  {
    std::map<int, Entry>::iterator it = socksConns_.find(connId);
    if (it == socksConns_.end())
    {
      LOG_WARN << "dropping " << len << " bytes for unknown connId " << connId;
      return;
    }
    if (len == 0)
    {
      return;
    }
    Entry& entry = it->second;
    if (entry.connection)
    {
      assert(entry.pending.readableBytes() == 0);
      entry.connection->send(data, static_cast<int>(len));
    }
    else
    {
      entry.pending.append(data, len);
    }
  }

  /// Handles "IS UP": creates the entry and starts connecting to the SOCKS server.
  void openSocksConn(int connId)
  {
    std::map<int, Entry>::iterator old = socksConns_.find(connId);
    if (old != socksConns_.end())
    {
      // The peer reused the id before our old SOCKS connection finished closing.
      LOG_WARN << "connId " << connId << " reused; dropping previous entry";
      socksConns_.erase(old);
    }
    char connName[256];
    snprintf(connName, sizeof connName, "SocksClient %d", connId);
    Entry& entry = socksConns_[connId];
    entry.connId = connId;
    entry.client = std::make_shared<TcpClient>(loop_, socksAddr_, connName);
    entry.client->setConnectionCallback(
        std::bind(&DemuxServer::onSocksConnection, this, connId, _1));
    entry.client->setMessageCallback(
        std::bind(&DemuxServer::onSocksMessage, this, connId, _1, _2, _3));
    // FIXME: setWriteCompleteCallback
    TcpClientPtr client = entry.client;
    client->connect();
  }

  /// Handles "IS DOWN": shuts down an open SOCKS connection (the entry is erased
  /// when it closes), or erases the entry right away if none is open.
  void closeSocksConn(int connId)
  {
    std::map<int, Entry>::iterator it = socksConns_.find(connId);
    if (it == socksConns_.end())
    {
      LOG_WARN << "IS DOWN for unknown connId " << connId;
      return;
    }
    Entry& entry = it->second;
    if (entry.connection && !entry.connection->disconnected())
    {
      entry.closedByMux = true;
      entry.connection->shutdown();
    }
    else
    {
      socksConns_.erase(it);
    }
  }

  EventLoop* loop_;
  TcpServer server_;
  TcpConnectionPtr serverConn_;
  const InetAddress socksAddr_;
  std::map<int, Entry> socksConns_;
};

int main(int argc, char* argv[])
{
  LOG_INFO << "pid = " << getpid();
  EventLoop loop;
  InetAddress listenAddr(kListenPort);
  if (argc > 1)
  {
    socksIp = argv[1];
  }
  InetAddress socksAddr(socksIp, kSocksPort);
  DemuxServer server(&loop, listenAddr, socksAddr);

  server.start();

  loop.loop();
}