Simplify the base server with the asyncio.

This commit is contained in:
Yuxin Wang 2018-10-05 18:08:16 -04:00
parent 49cd3f5a35
commit 604464c87f

View file

@ -1,151 +1,72 @@
import threading
import socket
from p2pfs.core.message import MessageType
from abc import abstractmethod
import json
import struct
import logging
import zstandard as zstd
import asyncio
logger = logging.getLogger(__name__)
class MessageServer:
""" Base class for async TCP server, provides useful _read_message and _write_message methods
for transferring message-based packets.
"""
_SOCKET_TIMEOUT = 5
def __init__(self, host, port):
self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._sock.bind((host, port))
self._sock.settimeout(MessageServer._SOCKET_TIMEOUT)
self._process_lock = threading.Lock()
def __init__(self, host, port, loop=None):
self._host = host
self._port = port
self._loop = loop
self._compressor = zstd.ZstdCompressor()
self._decompressor = zstd.ZstdDecompressor()
self._is_running = True
# manage the connections
self._connections_lock = threading.Lock()
self._connections = set()
self._writers = set()
# manage the corresponding threads
self._threads = set()
async def start(self):
logger.info('Start listening on {}'.format((self._host, self._port)))
# start server
await asyncio.start_server(self.__new_connection, self._host, self._port, loop=self._loop)
def start(self):
self._sock.listen(5)
logger.info('Start listening on {}'.format(self._sock.getsockname()))
# put server listening into a thread
thread = threading.Thread(target=self._listen)
thread.start()
self._threads.add(thread)
return True
def stop(self):
# shutdown the server
self._is_running = False
self._sock.close()
# close all connections
for client in self._connections:
client.close()
for thread in self._threads:
thread.join()
def _listen(self):
try:
while self._is_running:
try:
client, address = self._sock.accept()
# add timeout to prevent waiting forever
client.settimeout(MessageServer._SOCKET_TIMEOUT)
logger.info('New connection from {}'.format(address))
with self._connections_lock:
self._client_connected(client)
self._connections.add(client)
thread = threading.Thread(target=self._read_message, args=(client,))
thread.start()
self._threads.add(thread)
except socket.timeout:
# ignore timeout exception
pass
except (ConnectionAbortedError, OSError) as e:
if self._is_running:
# if exception occurred during normal execution
logger.error(e)
else:
pass
def _connect(self, address):
logger.info('Connecting to {}'.format(address))
client = socket.create_connection(address)
client.settimeout(MessageServer._SOCKET_TIMEOUT)
with self._connections_lock:
self._connections.add(client)
thread = threading.Thread(target=self._read_message, args=(client,))
thread.start()
self._threads.add(thread)
logger.info('Successfully connected to {} on {}'.format(address, client.getsockname()))
return client
async def stop(self):
for writer in self._writers:
writer.close()
await writer.wait_close()
@staticmethod
def __message_log(message):
return {key: message[key] for key in message if key != 'data'} if 'data' in message else message
log_message = {key: message[key] for key in message if key != 'data'}
log_message['type'] = MessageType(message['type']).name
return log_message
@staticmethod
def __recvall(sock, n):
"""helper function to recv n bytes or raise exception if EOF is hit"""
data = b''
while len(data) < n:
try:
packet = sock.recv(n - len(data))
if not packet:
raise EOFError('peer socket closed')
data += packet
except socket.timeout:
pass
return data
async def _read_message(self, reader):
assert isinstance(reader, asyncio.StreamReader)
# receive length header -> decompress (bytes) -> decode to str (str) -> json load (dict)
raw_msg_len = await reader.readexactly(4)
msglen = struct.unpack('>I', raw_msg_len)[0]
raw_msg = await reader.readexactly(msglen)
msg = json.loads(self._decompressor.decompress(raw_msg).decode('utf-8'))
logger.debug('Message received {}'.format(self.__message_log(msg)))
return msg
def _read_message(self, client):
assert isinstance(client, socket.socket)
try:
while True:
# receive length header -> decompress (bytes) -> decode to str (str) -> json load (dict)
raw_msg_len = self.__recvall(client, 4)
msglen = struct.unpack('>I', raw_msg_len)[0]
raw_msg = self.__recvall(client, msglen)
msg = json.loads(self._decompressor.decompress(raw_msg).decode('utf-8'))
logger.debug('Message {} from {}'.format(self.__message_log(msg), client.getpeername()))
# process the packets in order
# TODO: remove this lock for better parallelism
with self._process_lock:
self._process_message(client, msg)
except (EOFError, OSError):
if self._is_running:
with self._connections_lock:
assert client in self._connections
self._connections.remove(client)
self._threads.remove(threading.current_thread())
self._client_closed(client)
client.close()
def _write_message(self, client, message):
assert isinstance(client, socket.socket)
logger.debug('Writing {} to {}'.format(self.__message_log(message), client.getpeername()))
async def _write_message(self, message, writer):
assert isinstance(writer, asyncio.StreamWriter)
logger.debug('Writing {}'.format(self.__message_log(message)))
# json string (str) -> encode to utf8 (bytes) -> compress (bytes) -> add length header (bytes)
raw_msg = json.dumps(message).encode('utf-8')
compressed = self._compressor.compress(raw_msg)
logger.debug('Compressed rate: {}'.format(len(compressed) / len(raw_msg)))
compressed = struct.pack('>I', len(compressed)) + compressed
client.setblocking(True)
client.sendall(compressed)
client.settimeout(MessageServer._SOCKET_TIMEOUT)
writer.write(compressed)
await writer.drain()
def _client_connected(self, client):
pass
async def __new_connection(self, reader, writer):
self._writers.add(writer)
await self._process_connection(reader, writer)
@abstractmethod
def _process_message(self, client, message):
pass
def _client_closed(self, client):
pass
async def _process_connection(self, reader, writer):
raise NotImplementedError