Make Peer asynchronous.

This commit is contained in:
Yuxin Wang 2018-10-06 14:18:15 -04:00
parent 2d71e1581d
commit e2f68790b7

View file

@ -8,6 +8,7 @@ import math
import pybase64 import pybase64
import json import json
import hashlib import hashlib
import asyncio
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -15,184 +16,165 @@ class Peer(MessageServer):
_CHUNK_SIZE = 512 * 1024 _CHUNK_SIZE = 512 * 1024
_HASH_FUNC = hashlib.sha256 _HASH_FUNC = hashlib.sha256
def __init__(self, host, port, server, server_port): def __init__(self, host, port, server, server_port, loop=None):
super().__init__(host, port) super().__init__(host, port, loop=loop)
self._serverconfig = (server, server_port) self._server_config = (server, server_port)
self._server_sock = None self._server_reader, self._server_writer = None, None
# (remote filename) <-> (local filename) # (remote filename) <-> (local filename)
self._file_map = {} self._file_map = {}
# lock and results for publish method self._pending_publish = set()
self._publish_lock = threading.Lock()
self._publish_results = {}
# lock and results for list_file method async def start(self):
self._file_list = None
self._file_list_lock = threading.Lock()
self._file_list_result = Queue()
# lock and results for download
self._download_lock = threading.Lock()
self._download_results = {}
def start(self):
# connect to server # connect to server
try: try:
self._server_sock = self._connect(self._serverconfig) self._server_reader, self._server_writer = \
await asyncio.open_connection(*self._server_config, loop=self._loop)
except ConnectionRefusedError: except ConnectionRefusedError:
logger.error('Server connection refused!') logger.error('Server connection refused!')
return False return False
# start the internal server # start the internal server
super().start() await super().start()
# send out register message # send out register message
logger.info('Requesting to register') logger.info('Requesting to register')
self._write_message(self._server_sock, { await self._write_message(self._server_writer, {
'type': MessageType.REQUEST_REGISTER, 'type': MessageType.REQUEST_REGISTER,
'address': self._sock.getsockname() 'address': self._server_config
}) })
message = await self._read_message(self._server_reader)
assert MessageType(message['type']) == MessageType.REPLY_REGISTER
logger.info('Successfully registered.')
return True return True
def publish(self, file): async def publish(self, local_file, remote_name=None):
if not os.path.exists(file): if not os.path.exists(local_file):
return False, 'File {} doesn\'t exist'.format(file) return False, 'File {} doesn\'t exist'.format(local_file)
path, filename = os.path.split(file) _, remote_name = os.path.split(local_file) if remote_name is None else remote_name
# guard the check to prevent 2 threads passing the check simultaneously
with self._publish_lock: if remote_name in self._pending_publish:
if filename in self._publish_results: return False, 'Publish file {} already in progress.'.format(local_file)
return False, 'Publish file {} already in progress.'.format(file)
self._publish_results[filename] = Queue(maxsize=1) self._pending_publish.add(remote_name)
# send out the request packet # send out the request packet
self._write_message(self._server_sock, { await self._write_message(self._server_writer, {
'type': MessageType.REQUEST_PUBLISH, 'type': MessageType.REQUEST_PUBLISH,
'filename': filename, 'filename': remote_name,
'fileinfo': {'size': os.stat(file).st_size}, 'fileinfo': {'size': os.stat(local_file).st_size},
'chunknum': math.ceil(os.stat(file).st_size / Peer._CHUNK_SIZE) 'chunknum': math.ceil(os.stat(local_file).st_size / Peer._CHUNK_SIZE)
}) })
# queue will block until the result is ready message = await self._read_message(self._server_reader)
is_success, message = self._publish_results[filename].get() assert MessageType(message['type']) == MessageType.REPLY_PUBLISH
if is_success: is_success, message = message['result'], message['message']
self._file_map[filename] = file
logger.info('File {} published on server with name {}'.format(file, filename))
else:
logger.info('File {} failed to publish, {}'.format(file, message))
# remove result if is_success:
with self._publish_lock: self._file_map[remote_name] = local_file
del self._publish_results[filename] logger.info('File {} published on server with name {}'.format(local_file, remote_name))
else:
logger.info('File {} failed to publish, {}'.format(local_file, message))
self._pending_publish.remove(remote_name)
return is_success, message return is_success, message
def list_file(self): async def list_file(self):
self._write_message(self._server_sock, { await self._write_message(self._server_writer, {
'type': MessageType.REQUEST_FILE_LIST, 'type': MessageType.REQUEST_FILE_LIST,
}) })
with self._file_list_lock: message = await self._read_message(self._server_reader)
self._file_list = self._file_list_result.get() assert MessageType(message['type']) == MessageType.REPLY_FILE_LIST
return self._file_list return message['file_list']
def download(self, file, destination, reporthook=None): async def download(self, file, destination, reporthook=None):
with self._file_list_lock: # request for file list
if self._file_list is None or file not in self._file_list.keys(): file_list = await self.list_file()
return False, 'Requested file {} does not exist, try list_file?'.format(file) if file not in file_list:
with self._download_lock: return False, 'Requested file {} does not exist, try list_file?'.format(file)
if file in self._download_results:
return False, 'Download {} already in progress.'.format(file)
self._download_results[file] = Queue()
self._write_message(self._server_sock, { await self._write_message(self._server_writer, {
'type': MessageType.REQUEST_FILE_LOCATION, 'type': MessageType.REQUEST_FILE_LOCATION,
'filename': file 'filename': file
}) })
# wait until reply is ready
fileinfo, chunkinfo = self._download_results[file].get() message = await self._read_message(self._server_reader)
totalchunknum = math.ceil(fileinfo['size'] / Peer._CHUNK_SIZE) assert MessageType(message['type']) == MessageType.REPLY_FILE_LOCATION
fileinfo, chunkinfo = message['fileinfo'], message['chunkinfo']
logger.debug('{}: {} ==> {}'.format(file, fileinfo, chunkinfo)) logger.debug('{}: {} ==> {}'.format(file, fileinfo, chunkinfo))
totalchunknum = math.ceil(fileinfo['size'] / Peer._CHUNK_SIZE)
# TODO: decide which peer to request chunk # TODO: decide which peer to request chunk
peers = {} peers = {}
try: # TODO: make it parallel
for chunknum in range(totalchunknum): for chunknum in range(totalchunknum):
for peer_address, possessed_chunks in chunkinfo.items(): for peer_address, possessed_chunks in chunkinfo.items():
if chunknum in possessed_chunks: if chunknum in possessed_chunks:
if peer_address not in peers: if peer_address not in peers:
# peer_address is a string, since JSON requires keys being strings # peer_address is a string, since JSON requires keys being strings
peers[peer_address] = self._connect(json.loads(peer_address)) peers[peer_address] = await asyncio.open_connection(*json.loads(peer_address), loop=self._loop)
# write the message to ask the chunk # write the message to ask the chunk
self._write_message(peers[peer_address], { await self._write_message(peers[peer_address][1], {
'type': MessageType.PEER_REQUEST_CHUNK, 'type': MessageType.PEER_REQUEST_CHUNK,
'filename': file, 'filename': file,
'chunknum': chunknum 'chunknum': chunknum
}) })
break break
# TODO: update chunkinfo after receiving each chunk # TODO: update chunkinfo after receiving each chunk
with open(destination + '.temp', 'wb') as dest_file: with open(destination + '.temp', 'wb') as dest_file:
self._file_map[file] = destination self._file_map[file] = destination
for i in range(totalchunknum): for i in range(totalchunknum):
number, data, digest = self._download_results[file].get() for address, (reader, _) in peers:
raw_data = pybase64.b64decode(data.encode('utf-8'), validate=True) assert isinstance(reader, asyncio.StreamReader)
# TODO: handle if corrupted while not reader.at_eof():
if Peer._HASH_FUNC(raw_data).hexdigest() != digest: message = await self._read_message(reader)
assert False number, data, digest = message['chunknum'], message['data'], message['digest']
dest_file.seek(number * Peer._CHUNK_SIZE, 0) raw_data = pybase64.b64decode(data.encode('utf-8'), validate=True)
dest_file.write(raw_data) # TODO: handle if corrupted
dest_file.flush() if Peer._HASH_FUNC(raw_data).hexdigest() != digest:
# send request chunk register to server assert False
self._write_message(self._server_sock, { dest_file.seek(number * Peer._CHUNK_SIZE, 0)
'type': MessageType.REQUEST_CHUNK_REGISTER, dest_file.write(raw_data)
'filename': file, dest_file.flush()
'chunknum': number # send request chunk register to server
}) await self._write_message(self._server_writer, {
if reporthook: 'type': MessageType.REQUEST_CHUNK_REGISTER,
reporthook(i + 1, Peer._CHUNK_SIZE, fileinfo['size']) 'filename': file,
logger.debug('Got {}\'s chunk # {}'.format(file, number)) 'chunknum': number
})
if reporthook:
reporthook(i + 1, Peer._CHUNK_SIZE, fileinfo['size'])
logger.debug('Got {}\'s chunk # {}'.format(file, number))
# change the temp file into the actual file # change the temp file into the actual file
os.rename(destination + '.temp', destination) os.rename(destination + '.temp', destination)
with self._download_lock: # close the connections
del self._download_results[file] for _, (_, writer) in peers:
writer.close()
finally: await writer.wait_closed()
# close the sockets no matter what happens
for _, client in peers.items():
client.close()
return True, 'File {} dowloaded to {}'.format(file, destination) return True, 'File {} dowloaded to {}'.format(file, destination)
def _process_message(self, client, message): async def _process_connection(self, reader, writer):
if message['type'] == MessageType.REPLY_REGISTER: assert isinstance(reader, asyncio.StreamReader) and isinstance(writer, asyncio.StreamWriter)
logger.info('Successfully registered.') while not reader.at_eof():
elif message['type'] == MessageType.REPLY_PUBLISH: message = await self._read_message(reader)
self._publish_results[message['filename']].put((message['result'], message['message'])) message_type = MessageType(message['type'])
elif message['type'] == MessageType.REPLY_FILE_LIST: if message_type == MessageType.PEER_REQUEST_CHUNK:
self._file_list_result.put(message['file_list']) assert message['filename'] in self._file_map, 'File {} requested does not exist'.format(message['filename'])
elif message['type'] == MessageType.REPLY_FILE_LOCATION: local_file = self._file_map[message['filename']]
self._download_results[message['filename']].put((message['fileinfo'], message['chunkinfo'])) with open(local_file, 'rb') as f:
elif message['type'] == MessageType.PEER_REQUEST_CHUNK: f.seek(message['chunknum'] * Peer._CHUNK_SIZE, 0)
assert message['filename'] in self._file_map, 'File {} requested does not exist'.format(message['filename']) raw_data = f.read(Peer._CHUNK_SIZE)
local_file = self._file_map[message['filename']] await self._write_message(writer, {
with open(local_file, 'rb') as f: 'type': MessageType.PEER_REPLY_CHUNK,
f.seek(message['chunknum'] * Peer._CHUNK_SIZE, 0) 'filename': message['filename'],
raw_data = f.read(Peer._CHUNK_SIZE) 'chunknum': message['chunknum'],
self._write_message(client, { 'data': pybase64.b64encode(raw_data).decode('utf-8'),
'type': MessageType.PEER_REPLY_CHUNK, 'digest': Peer._HASH_FUNC(raw_data).hexdigest()
'filename': message['filename'], })
'chunknum': message['chunknum'], else:
'data': pybase64.b64encode(raw_data).decode('utf-8'), logger.error('Undefined message with type {}, full message: {}'.format(message['type'], message))
'digest': Peer._HASH_FUNC(raw_data).hexdigest()
})
elif message['type'] == MessageType.PEER_REPLY_CHUNK:
self._download_results[message['filename']].put((message['chunknum'], message['data'], message['digest']))
else:
logger.error('Undefined message with type {}, full message: {}'.format(message['type'], message))
def _client_closed(self, client):
# TODO: hanlde client closed unexpectedly
assert isinstance(client, socket.socket)
if client is self._server_sock:
logger.error('Server {} closed unexpectedly'.format(client.getpeername()))
exit(1)