Extract read_message/write_message/MessageType into message module.

This commit is contained in:
Yuxin Wang 2018-10-11 21:30:01 -04:00
parent 400be652a5
commit 838207972b
2 changed files with 59 additions and 55 deletions

58
p2pfs/core/message.py Normal file
View file

@ -0,0 +1,58 @@
from enum import Enum, auto
import logging
import asyncio
import struct
import json
import zstandard as zstd
_compressor = zstd.ZstdCompressor()
_decompressor = zstd.ZstdDecompressor()
logger = logging.getLogger(__name__)
class MessageType(Enum):
REQUEST_REGISTER = auto()
REQUEST_PUBLISH = auto()
REQUEST_FILE_LIST = auto()
REQUEST_FILE_LOCATION = auto()
REQUEST_CHUNK_REGISTER = auto()
REPLY_REGISTER = auto()
REPLY_FILE_LIST = auto()
REPLY_PUBLISH = auto()
REPLY_FILE_LOCATION = auto()
PEER_REQUEST_CHUNK = auto()
PEER_REPLY_CHUNK = auto()
PEER_PING_PONG = auto()
def _message_log(message):
log_message = {key: message[key] for key in message if key != 'data'}
log_message['type'] = MessageType(message['type']).name
return log_message
async def read_message(reader):
assert isinstance(reader, asyncio.StreamReader)
# receive length header -> decompress (bytes) -> decode to str (str) -> json load (dict)
raw_msg_len = await reader.readexactly(4)
msglen = struct.unpack('>I', raw_msg_len)[0]
raw_msg = await reader.readexactly(msglen)
msg = json.loads(_decompressor.decompress(raw_msg).decode('utf-8'))
logger.debug('Message received {}'.format(_message_log(msg)))
return msg
async def write_message(writer, message):
assert isinstance(writer, asyncio.StreamWriter)
logger.debug('Writing {}'.format(_message_log(message)))
# use value of enum since Enum is not JSON serializable
if isinstance(message['type'], MessageType):
message['type'] = message['type'].value
# json string (str) -> encode to utf8 (bytes) -> compress (bytes) -> add length header (bytes)
raw_msg = json.dumps(message).encode('utf-8')
compressed = _compressor.compress(raw_msg)
logger.debug('Compressed rate: {}'.format(len(compressed) / len(raw_msg)))
compressed = struct.pack('>I', len(compressed)) + compressed
writer.write(compressed)
await writer.drain()

View file

@ -1,31 +1,11 @@
from abc import abstractmethod
import json
import struct
import logging
import asyncio
from enum import Enum, auto
import zstandard as zstd
logger = logging.getLogger(__name__)
class MessageType(Enum):
REQUEST_REGISTER = auto()
REQUEST_PUBLISH = auto()
REQUEST_FILE_LIST = auto()
REQUEST_FILE_LOCATION = auto()
REQUEST_CHUNK_REGISTER = auto()
REPLY_REGISTER = auto()
REPLY_FILE_LIST = auto()
REPLY_PUBLISH = auto()
REPLY_FILE_LOCATION = auto()
PEER_REQUEST_CHUNK = auto()
PEER_REPLY_CHUNK = auto()
PEER_PING_PONG = auto()
class MessageServer:
""" Base class for async TCP server, provides useful _read_message and _write_message methods
for transferring message-based packets.
""" Base class for async TCP server, provides basic start and stop methods.
"""
_SOCKET_TIMEOUT = 5
@ -38,9 +18,6 @@ class MessageServer:
self._writers = set()
self._server = None
self._compressor = zstd.ZstdCompressor()
self._decompressor = zstd.ZstdDecompressor()
def is_running(self):
return self._is_running
@ -70,37 +47,6 @@ class MessageServer:
self._writers = set()
@staticmethod
def _message_log(message):
log_message = {key: message[key] for key in message if key != 'data'}
log_message['type'] = MessageType(message['type']).name
return log_message
async def _read_message(self, reader):
assert isinstance(reader, asyncio.StreamReader)
# receive length header -> decompress (bytes) -> decode to str (str) -> json load (dict)
raw_msg_len = await reader.readexactly(4)
msglen = struct.unpack('>I', raw_msg_len)[0]
raw_msg = await reader.readexactly(msglen)
msg = json.loads(self._decompressor.decompress(raw_msg).decode('utf-8'))
logger.debug('Message received {}'.format(self._message_log(msg)))
return msg
async def _write_message(self, writer, message):
assert isinstance(writer, asyncio.StreamWriter)
logger.debug('Writing {}'.format(self._message_log(message)))
# use value of enum since Enum is not JSON serializable
if isinstance(message['type'], MessageType):
message['type'] = message['type'].value
# json string (str) -> encode to utf8 (bytes) -> compress (bytes) -> add length header (bytes)
raw_msg = json.dumps(message).encode('utf-8')
compressed = self._compressor.compress(raw_msg)
logger.debug('Compressed rate: {}'.format(len(compressed) / len(raw_msg)))
compressed = struct.pack('>I', len(compressed)) + compressed
writer.write(compressed)
await writer.drain()
async def __new_connection(self, reader, writer):
self._writers.add(writer)
try: