Simplify the base server with the asyncio.
This commit is contained in:
parent
49cd3f5a35
commit
604464c87f
1 changed files with 42 additions and 121 deletions
|
@ -1,151 +1,72 @@
|
|||
import threading
|
||||
import socket
|
||||
from p2pfs.core.message import MessageType
|
||||
from abc import abstractmethod
|
||||
import json
|
||||
import struct
|
||||
import logging
|
||||
import zstandard as zstd
|
||||
import asyncio
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageServer:
|
||||
""" Base class for async TCP server, provides useful _read_message and _write_message methods
|
||||
for transferring message-based packets.
|
||||
"""
|
||||
_SOCKET_TIMEOUT = 5
|
||||
|
||||
def __init__(self, host, port):
|
||||
self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self._sock.bind((host, port))
|
||||
self._sock.settimeout(MessageServer._SOCKET_TIMEOUT)
|
||||
|
||||
self._process_lock = threading.Lock()
|
||||
def __init__(self, host, port, loop=None):
|
||||
self._host = host
|
||||
self._port = port
|
||||
self._loop = loop
|
||||
|
||||
self._compressor = zstd.ZstdCompressor()
|
||||
self._decompressor = zstd.ZstdDecompressor()
|
||||
|
||||
self._is_running = True
|
||||
|
||||
# manage the connections
|
||||
self._connections_lock = threading.Lock()
|
||||
self._connections = set()
|
||||
self._writers = set()
|
||||
|
||||
# manage the corresponding threads
|
||||
self._threads = set()
|
||||
async def start(self):
|
||||
logger.info('Start listening on {}'.format((self._host, self._port)))
|
||||
# start server
|
||||
await asyncio.start_server(self.__new_connection, self._host, self._port, loop=self._loop)
|
||||
|
||||
def start(self):
|
||||
self._sock.listen(5)
|
||||
logger.info('Start listening on {}'.format(self._sock.getsockname()))
|
||||
# put server listening into a thread
|
||||
thread = threading.Thread(target=self._listen)
|
||||
thread.start()
|
||||
self._threads.add(thread)
|
||||
return True
|
||||
|
||||
def stop(self):
|
||||
# shutdown the server
|
||||
self._is_running = False
|
||||
self._sock.close()
|
||||
# close all connections
|
||||
for client in self._connections:
|
||||
client.close()
|
||||
|
||||
for thread in self._threads:
|
||||
thread.join()
|
||||
|
||||
def _listen(self):
|
||||
try:
|
||||
while self._is_running:
|
||||
try:
|
||||
client, address = self._sock.accept()
|
||||
# add timeout to prevent waiting forever
|
||||
client.settimeout(MessageServer._SOCKET_TIMEOUT)
|
||||
logger.info('New connection from {}'.format(address))
|
||||
with self._connections_lock:
|
||||
self._client_connected(client)
|
||||
self._connections.add(client)
|
||||
thread = threading.Thread(target=self._read_message, args=(client,))
|
||||
thread.start()
|
||||
self._threads.add(thread)
|
||||
except socket.timeout:
|
||||
# ignore timeout exception
|
||||
pass
|
||||
except (ConnectionAbortedError, OSError) as e:
|
||||
if self._is_running:
|
||||
# if exception occurred during normal execution
|
||||
logger.error(e)
|
||||
else:
|
||||
pass
|
||||
|
||||
def _connect(self, address):
|
||||
logger.info('Connecting to {}'.format(address))
|
||||
client = socket.create_connection(address)
|
||||
client.settimeout(MessageServer._SOCKET_TIMEOUT)
|
||||
with self._connections_lock:
|
||||
self._connections.add(client)
|
||||
thread = threading.Thread(target=self._read_message, args=(client,))
|
||||
thread.start()
|
||||
self._threads.add(thread)
|
||||
logger.info('Successfully connected to {} on {}'.format(address, client.getsockname()))
|
||||
return client
|
||||
async def stop(self):
|
||||
for writer in self._writers:
|
||||
writer.close()
|
||||
await writer.wait_close()
|
||||
|
||||
@staticmethod
|
||||
def __message_log(message):
|
||||
return {key: message[key] for key in message if key != 'data'} if 'data' in message else message
|
||||
log_message = {key: message[key] for key in message if key != 'data'}
|
||||
log_message['type'] = MessageType(message['type']).name
|
||||
return log_message
|
||||
|
||||
@staticmethod
|
||||
def __recvall(sock, n):
|
||||
"""helper function to recv n bytes or raise exception if EOF is hit"""
|
||||
data = b''
|
||||
while len(data) < n:
|
||||
try:
|
||||
packet = sock.recv(n - len(data))
|
||||
if not packet:
|
||||
raise EOFError('peer socket closed')
|
||||
data += packet
|
||||
except socket.timeout:
|
||||
pass
|
||||
return data
|
||||
async def _read_message(self, reader):
|
||||
assert isinstance(reader, asyncio.StreamReader)
|
||||
# receive length header -> decompress (bytes) -> decode to str (str) -> json load (dict)
|
||||
raw_msg_len = await reader.readexactly(4)
|
||||
msglen = struct.unpack('>I', raw_msg_len)[0]
|
||||
raw_msg = await reader.readexactly(msglen)
|
||||
msg = json.loads(self._decompressor.decompress(raw_msg).decode('utf-8'))
|
||||
logger.debug('Message received {}'.format(self.__message_log(msg)))
|
||||
return msg
|
||||
|
||||
def _read_message(self, client):
|
||||
assert isinstance(client, socket.socket)
|
||||
try:
|
||||
while True:
|
||||
# receive length header -> decompress (bytes) -> decode to str (str) -> json load (dict)
|
||||
raw_msg_len = self.__recvall(client, 4)
|
||||
msglen = struct.unpack('>I', raw_msg_len)[0]
|
||||
raw_msg = self.__recvall(client, msglen)
|
||||
msg = json.loads(self._decompressor.decompress(raw_msg).decode('utf-8'))
|
||||
logger.debug('Message {} from {}'.format(self.__message_log(msg), client.getpeername()))
|
||||
# process the packets in order
|
||||
# TODO: remove this lock for better parallelism
|
||||
with self._process_lock:
|
||||
self._process_message(client, msg)
|
||||
except (EOFError, OSError):
|
||||
if self._is_running:
|
||||
with self._connections_lock:
|
||||
assert client in self._connections
|
||||
self._connections.remove(client)
|
||||
self._threads.remove(threading.current_thread())
|
||||
self._client_closed(client)
|
||||
client.close()
|
||||
|
||||
def _write_message(self, client, message):
|
||||
assert isinstance(client, socket.socket)
|
||||
logger.debug('Writing {} to {}'.format(self.__message_log(message), client.getpeername()))
|
||||
async def _write_message(self, message, writer):
|
||||
assert isinstance(writer, asyncio.StreamWriter)
|
||||
logger.debug('Writing {}'.format(self.__message_log(message)))
|
||||
# json string (str) -> encode to utf8 (bytes) -> compress (bytes) -> add length header (bytes)
|
||||
raw_msg = json.dumps(message).encode('utf-8')
|
||||
compressed = self._compressor.compress(raw_msg)
|
||||
logger.debug('Compressed rate: {}'.format(len(compressed) / len(raw_msg)))
|
||||
compressed = struct.pack('>I', len(compressed)) + compressed
|
||||
client.setblocking(True)
|
||||
client.sendall(compressed)
|
||||
client.settimeout(MessageServer._SOCKET_TIMEOUT)
|
||||
writer.write(compressed)
|
||||
await writer.drain()
|
||||
|
||||
def _client_connected(self, client):
|
||||
pass
|
||||
async def __new_connection(self, reader, writer):
|
||||
self._writers.add(writer)
|
||||
await self._process_connection(reader, writer)
|
||||
|
||||
@abstractmethod
|
||||
def _process_message(self, client, message):
|
||||
pass
|
||||
|
||||
def _client_closed(self, client):
|
||||
pass
|
||||
|
||||
|
||||
async def _process_connection(self, reader, writer):
|
||||
raise NotImplementedError
|
||||
|
|
Loading…
Reference in a new issue