Simplify the base server with the asyncio.

This commit is contained in:
Yuxin Wang 2018-10-05 18:08:16 -04:00
parent 49cd3f5a35
commit 604464c87f

View file

@ -1,151 +1,72 @@
import threading from p2pfs.core.message import MessageType
import socket
from abc import abstractmethod from abc import abstractmethod
import json import json
import struct import struct
import logging import logging
import zstandard as zstd import zstandard as zstd
import asyncio
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class MessageServer: class MessageServer:
""" Base class for async TCP server, provides useful _read_message and _write_message methods
for transferring message-based packets.
"""
_SOCKET_TIMEOUT = 5 _SOCKET_TIMEOUT = 5
def __init__(self, host, port): def __init__(self, host, port, loop=None):
self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._host = host
self._sock.bind((host, port)) self._port = port
self._sock.settimeout(MessageServer._SOCKET_TIMEOUT) self._loop = loop
self._process_lock = threading.Lock()
self._compressor = zstd.ZstdCompressor() self._compressor = zstd.ZstdCompressor()
self._decompressor = zstd.ZstdDecompressor() self._decompressor = zstd.ZstdDecompressor()
self._is_running = True
# manage the connections # manage the connections
self._connections_lock = threading.Lock() self._writers = set()
self._connections = set()
# manage the corresponding threads async def start(self):
self._threads = set() logger.info('Start listening on {}'.format((self._host, self._port)))
# start server
await asyncio.start_server(self.__new_connection, self._host, self._port, loop=self._loop)
def start(self): async def stop(self):
self._sock.listen(5) for writer in self._writers:
logger.info('Start listening on {}'.format(self._sock.getsockname())) writer.close()
# put server listening into a thread await writer.wait_close()
thread = threading.Thread(target=self._listen)
thread.start()
self._threads.add(thread)
return True
def stop(self):
# shutdown the server
self._is_running = False
self._sock.close()
# close all connections
for client in self._connections:
client.close()
for thread in self._threads:
thread.join()
def _listen(self):
try:
while self._is_running:
try:
client, address = self._sock.accept()
# add timeout to prevent waiting forever
client.settimeout(MessageServer._SOCKET_TIMEOUT)
logger.info('New connection from {}'.format(address))
with self._connections_lock:
self._client_connected(client)
self._connections.add(client)
thread = threading.Thread(target=self._read_message, args=(client,))
thread.start()
self._threads.add(thread)
except socket.timeout:
# ignore timeout exception
pass
except (ConnectionAbortedError, OSError) as e:
if self._is_running:
# if exception occurred during normal execution
logger.error(e)
else:
pass
def _connect(self, address):
logger.info('Connecting to {}'.format(address))
client = socket.create_connection(address)
client.settimeout(MessageServer._SOCKET_TIMEOUT)
with self._connections_lock:
self._connections.add(client)
thread = threading.Thread(target=self._read_message, args=(client,))
thread.start()
self._threads.add(thread)
logger.info('Successfully connected to {} on {}'.format(address, client.getsockname()))
return client
@staticmethod @staticmethod
def __message_log(message): def __message_log(message):
return {key: message[key] for key in message if key != 'data'} if 'data' in message else message log_message = {key: message[key] for key in message if key != 'data'}
log_message['type'] = MessageType(message['type']).name
return log_message
@staticmethod async def _read_message(self, reader):
def __recvall(sock, n): assert isinstance(reader, asyncio.StreamReader)
"""helper function to recv n bytes or raise exception if EOF is hit""" # receive length header -> decompress (bytes) -> decode to str (str) -> json load (dict)
data = b'' raw_msg_len = await reader.readexactly(4)
while len(data) < n: msglen = struct.unpack('>I', raw_msg_len)[0]
try: raw_msg = await reader.readexactly(msglen)
packet = sock.recv(n - len(data)) msg = json.loads(self._decompressor.decompress(raw_msg).decode('utf-8'))
if not packet: logger.debug('Message received {}'.format(self.__message_log(msg)))
raise EOFError('peer socket closed') return msg
data += packet
except socket.timeout:
pass
return data
def _read_message(self, client): async def _write_message(self, message, writer):
assert isinstance(client, socket.socket) assert isinstance(writer, asyncio.StreamWriter)
try: logger.debug('Writing {}'.format(self.__message_log(message)))
while True:
# receive length header -> decompress (bytes) -> decode to str (str) -> json load (dict)
raw_msg_len = self.__recvall(client, 4)
msglen = struct.unpack('>I', raw_msg_len)[0]
raw_msg = self.__recvall(client, msglen)
msg = json.loads(self._decompressor.decompress(raw_msg).decode('utf-8'))
logger.debug('Message {} from {}'.format(self.__message_log(msg), client.getpeername()))
# process the packets in order
# TODO: remove this lock for better parallelism
with self._process_lock:
self._process_message(client, msg)
except (EOFError, OSError):
if self._is_running:
with self._connections_lock:
assert client in self._connections
self._connections.remove(client)
self._threads.remove(threading.current_thread())
self._client_closed(client)
client.close()
def _write_message(self, client, message):
assert isinstance(client, socket.socket)
logger.debug('Writing {} to {}'.format(self.__message_log(message), client.getpeername()))
# json string (str) -> encode to utf8 (bytes) -> compress (bytes) -> add length header (bytes) # json string (str) -> encode to utf8 (bytes) -> compress (bytes) -> add length header (bytes)
raw_msg = json.dumps(message).encode('utf-8') raw_msg = json.dumps(message).encode('utf-8')
compressed = self._compressor.compress(raw_msg) compressed = self._compressor.compress(raw_msg)
logger.debug('Compressed rate: {}'.format(len(compressed) / len(raw_msg))) logger.debug('Compressed rate: {}'.format(len(compressed) / len(raw_msg)))
compressed = struct.pack('>I', len(compressed)) + compressed compressed = struct.pack('>I', len(compressed)) + compressed
client.setblocking(True) writer.write(compressed)
client.sendall(compressed) await writer.drain()
client.settimeout(MessageServer._SOCKET_TIMEOUT)
def _client_connected(self, client): async def __new_connection(self, reader, writer):
pass self._writers.add(writer)
await self._process_connection(reader, writer)
@abstractmethod @abstractmethod
def _process_message(self, client, message): async def _process_connection(self, reader, writer):
pass raise NotImplementedError
def _client_closed(self, client):
pass