Extract read_message/write_message/MessageType into message module.
This commit is contained in:
parent
400be652a5
commit
838207972b
2 changed files with 59 additions and 55 deletions
58
p2pfs/core/message.py
Normal file
58
p2pfs/core/message.py
Normal file
|
@ -0,0 +1,58 @@
|
|||
from enum import Enum, auto
|
||||
import logging
|
||||
import asyncio
|
||||
import struct
|
||||
import json
|
||||
import zstandard as zstd
|
||||
|
||||
_compressor = zstd.ZstdCompressor()
|
||||
_decompressor = zstd.ZstdDecompressor()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageType(Enum):
    """Enumerates the message types exchanged via read_message/write_message.

    Members are auto-numbered; the integer ``.value`` is what travels on the
    wire, because Enum members are not JSON serializable (see write_message).
    NOTE: member order therefore fixes the wire protocol — do not reorder.
    """
    # request messages
    REQUEST_REGISTER = auto()
    REQUEST_PUBLISH = auto()
    REQUEST_FILE_LIST = auto()
    REQUEST_FILE_LOCATION = auto()
    REQUEST_CHUNK_REGISTER = auto()
    # reply messages
    REPLY_REGISTER = auto()
    REPLY_FILE_LIST = auto()
    REPLY_PUBLISH = auto()
    REPLY_FILE_LOCATION = auto()
    # peer-to-peer messages
    PEER_REQUEST_CHUNK = auto()
    PEER_REPLY_CHUNK = auto()
    PEER_PING_PONG = auto()
|
||||
|
||||
|
||||
def _message_log(message):
    """Return a log-friendly copy of *message*.

    The (potentially large) 'data' entry is dropped and the numeric 'type'
    is replaced by its MessageType name for readable log output.
    """
    loggable = dict(message)
    loggable.pop('data', None)
    loggable['type'] = MessageType(loggable['type']).name
    return loggable
|
||||
|
||||
|
||||
async def read_message(reader):
    """Receive one complete message from *reader* and return it as a dict.

    Wire format (inverse of write_message): a 4-byte big-endian length
    header followed by that many bytes of zstd-compressed, utf-8 encoded
    JSON.

    :param reader: the asyncio.StreamReader to read from.
    :return: the decoded message dict.
    """
    assert isinstance(reader, asyncio.StreamReader)
    # length header first, then exactly that many payload bytes
    header = await reader.readexactly(4)
    (payload_len,) = struct.unpack('>I', header)
    payload = await reader.readexactly(payload_len)

    # decompress (bytes) -> decode to str (str) -> json load (dict)
    message = json.loads(_decompressor.decompress(payload).decode('utf-8'))
    logger.debug('Message received {}'.format(_message_log(message)))
    return message
|
||||
|
||||
|
||||
async def write_message(writer, message):
    """Serialize, compress and send *message* over *writer*.

    Wire format (inverse of read_message): a 4-byte big-endian length
    header followed by the zstd-compressed, utf-8 encoded JSON body.

    :param writer: the asyncio.StreamWriter to send on.
    :param message: dict to send; message['type'] may be either a
        MessageType member or its integer value.
    """
    assert isinstance(writer, asyncio.StreamWriter)
    logger.debug('Writing {}'.format(_message_log(message)))
    # use value of enum since Enum is not JSON serializable; substitute on a
    # shallow copy so the caller's dict is not mutated as a side effect
    # (the original wrote message['type'] back into the caller's object)
    if isinstance(message['type'], MessageType):
        message = {**message, 'type': message['type'].value}
    # json string (str) -> encode to utf8 (bytes) -> compress (bytes) -> add length header (bytes)
    raw_msg = json.dumps(message).encode('utf-8')
    compressed = _compressor.compress(raw_msg)
    logger.debug('Compressed rate: {}'.format(len(compressed) / len(raw_msg)))
    compressed = struct.pack('>I', len(compressed)) + compressed
    writer.write(compressed)
    await writer.drain()
|
|
@ -1,31 +1,11 @@
|
|||
from abc import abstractmethod
|
||||
import json
|
||||
import struct
|
||||
import logging
|
||||
import asyncio
|
||||
from enum import Enum, auto
|
||||
import zstandard as zstd
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageType(Enum):
    """Enumerates the message types used by the tracker/peer protocol.

    Members are auto-numbered; the integer ``.value`` is what is sent on the
    wire, because Enum members are not JSON serializable.
    NOTE: member order therefore fixes the wire protocol — do not reorder.
    """
    # request messages
    REQUEST_REGISTER = auto()
    REQUEST_PUBLISH = auto()
    REQUEST_FILE_LIST = auto()
    REQUEST_FILE_LOCATION = auto()
    REQUEST_CHUNK_REGISTER = auto()
    # reply messages
    REPLY_REGISTER = auto()
    REPLY_FILE_LIST = auto()
    REPLY_PUBLISH = auto()
    REPLY_FILE_LOCATION = auto()
    # peer-to-peer messages
    PEER_REQUEST_CHUNK = auto()
    PEER_REPLY_CHUNK = auto()
    PEER_PING_PONG = auto()
|
||||
|
||||
|
||||
class MessageServer:
|
||||
""" Base class for async TCP server, provides useful _read_message and _write_message methods
|
||||
for transferring message-based packets.
|
||||
""" Base class for async TCP server, provides basic start and stop methods.
|
||||
"""
|
||||
_SOCKET_TIMEOUT = 5
|
||||
|
||||
|
@ -38,9 +18,6 @@ class MessageServer:
|
|||
self._writers = set()
|
||||
self._server = None
|
||||
|
||||
self._compressor = zstd.ZstdCompressor()
|
||||
self._decompressor = zstd.ZstdDecompressor()
|
||||
|
||||
def is_running(self):
|
||||
return self._is_running
|
||||
|
||||
|
@ -70,37 +47,6 @@ class MessageServer:
|
|||
|
||||
self._writers = set()
|
||||
|
||||
@staticmethod
def _message_log(message):
    """Return a log-friendly copy of *message*: the (potentially large)
    'data' entry is dropped and the numeric 'type' is replaced by its
    MessageType name."""
    log_message = {key: message[key] for key in message if key != 'data'}
    log_message['type'] = MessageType(message['type']).name
    return log_message
|
||||
|
||||
async def _read_message(self, reader):
    """Receive one complete message from *reader* and return it as a dict.

    Wire format (inverse of _write_message): a 4-byte big-endian length
    header followed by that many bytes of zstd-compressed, utf-8 encoded
    JSON.
    """
    assert isinstance(reader, asyncio.StreamReader)
    # receive length header -> decompress (bytes) -> decode to str (str) -> json load (dict)
    raw_msg_len = await reader.readexactly(4)
    msglen = struct.unpack('>I', raw_msg_len)[0]
    raw_msg = await reader.readexactly(msglen)

    msg = json.loads(self._decompressor.decompress(raw_msg).decode('utf-8'))
    logger.debug('Message received {}'.format(self._message_log(msg)))
    return msg
|
||||
|
||||
async def _write_message(self, writer, message):
    """Serialize, compress and send *message* over *writer*.

    Wire format (inverse of _read_message): a 4-byte big-endian length
    header followed by the zstd-compressed, utf-8 encoded JSON body.

    NOTE(review): the caller's dict is mutated in place when 'type' is a
    MessageType member — confirm callers do not rely on keeping the member.
    """
    assert isinstance(writer, asyncio.StreamWriter)
    logger.debug('Writing {}'.format(self._message_log(message)))
    # use value of enum since Enum is not JSON serializable
    if isinstance(message['type'], MessageType):
        message['type'] = message['type'].value
    # json string (str) -> encode to utf8 (bytes) -> compress (bytes) -> add length header (bytes)
    raw_msg = json.dumps(message).encode('utf-8')
    compressed = self._compressor.compress(raw_msg)
    logger.debug('Compressed rate: {}'.format(len(compressed) / len(raw_msg)))
    compressed = struct.pack('>I', len(compressed)) + compressed
    writer.write(compressed)
    await writer.drain()
|
||||
|
||||
async def __new_connection(self, reader, writer):
|
||||
self._writers.add(writer)
|
||||
try:
|
||||
|
|
Loading…
Reference in a new issue