Make Peer asynchronous.
This commit is contained in:
parent
2d71e1581d
commit
e2f68790b7
1 changed files with 115 additions and 133 deletions
|
@ -8,6 +8,7 @@ import math
|
||||||
import pybase64
|
import pybase64
|
||||||
import json
|
import json
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import asyncio
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -15,184 +16,165 @@ class Peer(MessageServer):
|
||||||
_CHUNK_SIZE = 512 * 1024
|
_CHUNK_SIZE = 512 * 1024
|
||||||
_HASH_FUNC = hashlib.sha256
|
_HASH_FUNC = hashlib.sha256
|
||||||
|
|
||||||
def __init__(self, host, port, server, server_port):
|
def __init__(self, host, port, server, server_port, loop=None):
|
||||||
super().__init__(host, port)
|
super().__init__(host, port, loop=loop)
|
||||||
self._serverconfig = (server, server_port)
|
self._server_config = (server, server_port)
|
||||||
self._server_sock = None
|
self._server_reader, self._server_writer = None, None
|
||||||
|
|
||||||
# (remote filename) <-> (local filename)
|
# (remote filename) <-> (local filename)
|
||||||
self._file_map = {}
|
self._file_map = {}
|
||||||
|
|
||||||
# lock and results for publish method
|
self._pending_publish = set()
|
||||||
self._publish_lock = threading.Lock()
|
|
||||||
self._publish_results = {}
|
|
||||||
|
|
||||||
# lock and results for list_file method
|
async def start(self):
|
||||||
self._file_list = None
|
|
||||||
self._file_list_lock = threading.Lock()
|
|
||||||
self._file_list_result = Queue()
|
|
||||||
|
|
||||||
# lock and results for download
|
|
||||||
self._download_lock = threading.Lock()
|
|
||||||
self._download_results = {}
|
|
||||||
|
|
||||||
def start(self):
|
|
||||||
# connect to server
|
# connect to server
|
||||||
try:
|
try:
|
||||||
self._server_sock = self._connect(self._serverconfig)
|
self._server_reader, self._server_writer = \
|
||||||
|
await asyncio.open_connection(*self._server_config, loop=self._loop)
|
||||||
except ConnectionRefusedError:
|
except ConnectionRefusedError:
|
||||||
logger.error('Server connection refused!')
|
logger.error('Server connection refused!')
|
||||||
return False
|
return False
|
||||||
# start the internal server
|
# start the internal server
|
||||||
super().start()
|
await super().start()
|
||||||
# send out register message
|
# send out register message
|
||||||
logger.info('Requesting to register')
|
logger.info('Requesting to register')
|
||||||
self._write_message(self._server_sock, {
|
await self._write_message(self._server_writer, {
|
||||||
'type': MessageType.REQUEST_REGISTER,
|
'type': MessageType.REQUEST_REGISTER,
|
||||||
'address': self._sock.getsockname()
|
'address': self._server_config
|
||||||
})
|
})
|
||||||
|
message = await self._read_message(self._server_reader)
|
||||||
|
assert MessageType(message['type']) == MessageType.REPLY_REGISTER
|
||||||
|
logger.info('Successfully registered.')
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def publish(self, file):
|
async def publish(self, local_file, remote_name=None):
|
||||||
if not os.path.exists(file):
|
if not os.path.exists(local_file):
|
||||||
return False, 'File {} doesn\'t exist'.format(file)
|
return False, 'File {} doesn\'t exist'.format(local_file)
|
||||||
|
|
||||||
path, filename = os.path.split(file)
|
_, remote_name = os.path.split(local_file) if remote_name is None else remote_name
|
||||||
# guard the check to prevent 2 threads passing the check simultaneously
|
|
||||||
with self._publish_lock:
|
if remote_name in self._pending_publish:
|
||||||
if filename in self._publish_results:
|
return False, 'Publish file {} already in progress.'.format(local_file)
|
||||||
return False, 'Publish file {} already in progress.'.format(file)
|
|
||||||
self._publish_results[filename] = Queue(maxsize=1)
|
self._pending_publish.add(remote_name)
|
||||||
|
|
||||||
# send out the request packet
|
# send out the request packet
|
||||||
self._write_message(self._server_sock, {
|
await self._write_message(self._server_writer, {
|
||||||
'type': MessageType.REQUEST_PUBLISH,
|
'type': MessageType.REQUEST_PUBLISH,
|
||||||
'filename': filename,
|
'filename': remote_name,
|
||||||
'fileinfo': {'size': os.stat(file).st_size},
|
'fileinfo': {'size': os.stat(local_file).st_size},
|
||||||
'chunknum': math.ceil(os.stat(file).st_size / Peer._CHUNK_SIZE)
|
'chunknum': math.ceil(os.stat(local_file).st_size / Peer._CHUNK_SIZE)
|
||||||
})
|
})
|
||||||
|
|
||||||
# queue will block until the result is ready
|
message = await self._read_message(self._server_reader)
|
||||||
is_success, message = self._publish_results[filename].get()
|
assert MessageType(message['type']) == MessageType.REPLY_PUBLISH
|
||||||
if is_success:
|
is_success, message = message['result'], message['message']
|
||||||
self._file_map[filename] = file
|
|
||||||
logger.info('File {} published on server with name {}'.format(file, filename))
|
|
||||||
else:
|
|
||||||
logger.info('File {} failed to publish, {}'.format(file, message))
|
|
||||||
|
|
||||||
# remove result
|
if is_success:
|
||||||
with self._publish_lock:
|
self._file_map[remote_name] = local_file
|
||||||
del self._publish_results[filename]
|
logger.info('File {} published on server with name {}'.format(local_file, remote_name))
|
||||||
|
else:
|
||||||
|
logger.info('File {} failed to publish, {}'.format(local_file, message))
|
||||||
|
|
||||||
|
self._pending_publish.remove(remote_name)
|
||||||
return is_success, message
|
return is_success, message
|
||||||
|
|
||||||
def list_file(self):
|
async def list_file(self):
|
||||||
self._write_message(self._server_sock, {
|
await self._write_message(self._server_writer, {
|
||||||
'type': MessageType.REQUEST_FILE_LIST,
|
'type': MessageType.REQUEST_FILE_LIST,
|
||||||
})
|
})
|
||||||
with self._file_list_lock:
|
message = await self._read_message(self._server_reader)
|
||||||
self._file_list = self._file_list_result.get()
|
assert MessageType(message['type']) == MessageType.REPLY_FILE_LIST
|
||||||
return self._file_list
|
return message['file_list']
|
||||||
|
|
||||||
def download(self, file, destination, reporthook=None):
|
async def download(self, file, destination, reporthook=None):
|
||||||
with self._file_list_lock:
|
# request for file list
|
||||||
if self._file_list is None or file not in self._file_list.keys():
|
file_list = await self.list_file()
|
||||||
return False, 'Requested file {} does not exist, try list_file?'.format(file)
|
if file not in file_list:
|
||||||
with self._download_lock:
|
return False, 'Requested file {} does not exist, try list_file?'.format(file)
|
||||||
if file in self._download_results:
|
|
||||||
return False, 'Download {} already in progress.'.format(file)
|
|
||||||
self._download_results[file] = Queue()
|
|
||||||
|
|
||||||
self._write_message(self._server_sock, {
|
await self._write_message(self._server_writer, {
|
||||||
'type': MessageType.REQUEST_FILE_LOCATION,
|
'type': MessageType.REQUEST_FILE_LOCATION,
|
||||||
'filename': file
|
'filename': file
|
||||||
})
|
})
|
||||||
# wait until reply is ready
|
|
||||||
fileinfo, chunkinfo = self._download_results[file].get()
|
message = await self._read_message(self._server_reader)
|
||||||
totalchunknum = math.ceil(fileinfo['size'] / Peer._CHUNK_SIZE)
|
assert MessageType(message['type']) == MessageType.REPLY_FILE_LOCATION
|
||||||
|
fileinfo, chunkinfo = message['fileinfo'], message['chunkinfo']
|
||||||
logger.debug('{}: {} ==> {}'.format(file, fileinfo, chunkinfo))
|
logger.debug('{}: {} ==> {}'.format(file, fileinfo, chunkinfo))
|
||||||
|
|
||||||
|
totalchunknum = math.ceil(fileinfo['size'] / Peer._CHUNK_SIZE)
|
||||||
|
|
||||||
# TODO: decide which peer to request chunk
|
# TODO: decide which peer to request chunk
|
||||||
peers = {}
|
peers = {}
|
||||||
try:
|
# TODO: make it parallel
|
||||||
for chunknum in range(totalchunknum):
|
for chunknum in range(totalchunknum):
|
||||||
for peer_address, possessed_chunks in chunkinfo.items():
|
for peer_address, possessed_chunks in chunkinfo.items():
|
||||||
if chunknum in possessed_chunks:
|
if chunknum in possessed_chunks:
|
||||||
if peer_address not in peers:
|
if peer_address not in peers:
|
||||||
# peer_address is a string, since JSON requires keys being strings
|
# peer_address is a string, since JSON requires keys being strings
|
||||||
peers[peer_address] = self._connect(json.loads(peer_address))
|
peers[peer_address] = await asyncio.open_connection(*json.loads(peer_address), loop=self._loop)
|
||||||
# write the message to ask the chunk
|
# write the message to ask the chunk
|
||||||
self._write_message(peers[peer_address], {
|
await self._write_message(peers[peer_address][1], {
|
||||||
'type': MessageType.PEER_REQUEST_CHUNK,
|
'type': MessageType.PEER_REQUEST_CHUNK,
|
||||||
'filename': file,
|
'filename': file,
|
||||||
'chunknum': chunknum
|
'chunknum': chunknum
|
||||||
})
|
})
|
||||||
break
|
break
|
||||||
|
|
||||||
# TODO: update chunkinfo after receiving each chunk
|
# TODO: update chunkinfo after receiving each chunk
|
||||||
with open(destination + '.temp', 'wb') as dest_file:
|
with open(destination + '.temp', 'wb') as dest_file:
|
||||||
self._file_map[file] = destination
|
self._file_map[file] = destination
|
||||||
for i in range(totalchunknum):
|
for i in range(totalchunknum):
|
||||||
number, data, digest = self._download_results[file].get()
|
for address, (reader, _) in peers:
|
||||||
raw_data = pybase64.b64decode(data.encode('utf-8'), validate=True)
|
assert isinstance(reader, asyncio.StreamReader)
|
||||||
# TODO: handle if corrupted
|
while not reader.at_eof():
|
||||||
if Peer._HASH_FUNC(raw_data).hexdigest() != digest:
|
message = await self._read_message(reader)
|
||||||
assert False
|
number, data, digest = message['chunknum'], message['data'], message['digest']
|
||||||
dest_file.seek(number * Peer._CHUNK_SIZE, 0)
|
raw_data = pybase64.b64decode(data.encode('utf-8'), validate=True)
|
||||||
dest_file.write(raw_data)
|
# TODO: handle if corrupted
|
||||||
dest_file.flush()
|
if Peer._HASH_FUNC(raw_data).hexdigest() != digest:
|
||||||
# send request chunk register to server
|
assert False
|
||||||
self._write_message(self._server_sock, {
|
dest_file.seek(number * Peer._CHUNK_SIZE, 0)
|
||||||
'type': MessageType.REQUEST_CHUNK_REGISTER,
|
dest_file.write(raw_data)
|
||||||
'filename': file,
|
dest_file.flush()
|
||||||
'chunknum': number
|
# send request chunk register to server
|
||||||
})
|
await self._write_message(self._server_writer, {
|
||||||
if reporthook:
|
'type': MessageType.REQUEST_CHUNK_REGISTER,
|
||||||
reporthook(i + 1, Peer._CHUNK_SIZE, fileinfo['size'])
|
'filename': file,
|
||||||
logger.debug('Got {}\'s chunk # {}'.format(file, number))
|
'chunknum': number
|
||||||
|
})
|
||||||
|
if reporthook:
|
||||||
|
reporthook(i + 1, Peer._CHUNK_SIZE, fileinfo['size'])
|
||||||
|
logger.debug('Got {}\'s chunk # {}'.format(file, number))
|
||||||
|
|
||||||
# change the temp file into the actual file
|
# change the temp file into the actual file
|
||||||
os.rename(destination + '.temp', destination)
|
os.rename(destination + '.temp', destination)
|
||||||
|
|
||||||
with self._download_lock:
|
# close the connections
|
||||||
del self._download_results[file]
|
for _, (_, writer) in peers:
|
||||||
|
writer.close()
|
||||||
finally:
|
await writer.wait_closed()
|
||||||
# close the sockets no matter what happens
|
|
||||||
for _, client in peers.items():
|
|
||||||
client.close()
|
|
||||||
|
|
||||||
return True, 'File {} dowloaded to {}'.format(file, destination)
|
return True, 'File {} dowloaded to {}'.format(file, destination)
|
||||||
|
|
||||||
def _process_message(self, client, message):
|
async def _process_connection(self, reader, writer):
|
||||||
if message['type'] == MessageType.REPLY_REGISTER:
|
assert isinstance(reader, asyncio.StreamReader) and isinstance(writer, asyncio.StreamWriter)
|
||||||
logger.info('Successfully registered.')
|
while not reader.at_eof():
|
||||||
elif message['type'] == MessageType.REPLY_PUBLISH:
|
message = await self._read_message(reader)
|
||||||
self._publish_results[message['filename']].put((message['result'], message['message']))
|
message_type = MessageType(message['type'])
|
||||||
elif message['type'] == MessageType.REPLY_FILE_LIST:
|
if message_type == MessageType.PEER_REQUEST_CHUNK:
|
||||||
self._file_list_result.put(message['file_list'])
|
assert message['filename'] in self._file_map, 'File {} requested does not exist'.format(message['filename'])
|
||||||
elif message['type'] == MessageType.REPLY_FILE_LOCATION:
|
local_file = self._file_map[message['filename']]
|
||||||
self._download_results[message['filename']].put((message['fileinfo'], message['chunkinfo']))
|
with open(local_file, 'rb') as f:
|
||||||
elif message['type'] == MessageType.PEER_REQUEST_CHUNK:
|
f.seek(message['chunknum'] * Peer._CHUNK_SIZE, 0)
|
||||||
assert message['filename'] in self._file_map, 'File {} requested does not exist'.format(message['filename'])
|
raw_data = f.read(Peer._CHUNK_SIZE)
|
||||||
local_file = self._file_map[message['filename']]
|
await self._write_message(writer, {
|
||||||
with open(local_file, 'rb') as f:
|
'type': MessageType.PEER_REPLY_CHUNK,
|
||||||
f.seek(message['chunknum'] * Peer._CHUNK_SIZE, 0)
|
'filename': message['filename'],
|
||||||
raw_data = f.read(Peer._CHUNK_SIZE)
|
'chunknum': message['chunknum'],
|
||||||
self._write_message(client, {
|
'data': pybase64.b64encode(raw_data).decode('utf-8'),
|
||||||
'type': MessageType.PEER_REPLY_CHUNK,
|
'digest': Peer._HASH_FUNC(raw_data).hexdigest()
|
||||||
'filename': message['filename'],
|
})
|
||||||
'chunknum': message['chunknum'],
|
else:
|
||||||
'data': pybase64.b64encode(raw_data).decode('utf-8'),
|
logger.error('Undefined message with type {}, full message: {}'.format(message['type'], message))
|
||||||
'digest': Peer._HASH_FUNC(raw_data).hexdigest()
|
|
||||||
})
|
|
||||||
elif message['type'] == MessageType.PEER_REPLY_CHUNK:
|
|
||||||
self._download_results[message['filename']].put((message['chunknum'], message['data'], message['digest']))
|
|
||||||
else:
|
|
||||||
logger.error('Undefined message with type {}, full message: {}'.format(message['type'], message))
|
|
||||||
|
|
||||||
def _client_closed(self, client):
|
|
||||||
# TODO: handle client closed unexpectedly
|
|
||||||
assert isinstance(client, socket.socket)
|
|
||||||
if client is self._server_sock:
|
|
||||||
logger.error('Server {} closed unexpectedly'.format(client.getpeername()))
|
|
||||||
exit(1)
|
|
Loading…
Reference in a new issue