Use read_message and write_message function instead of methods.

Also add DownloadManager class.
This commit is contained in:
Yuxin Wang 2018-10-11 23:08:00 -04:00
parent e4567c45a7
commit c9ba7fbe77

View file

@ -6,10 +6,90 @@ import time
import hashlib
import asyncio
import pybase64
from p2pfs.core.server import MessageServer, MessageType
from p2pfs.core.message import MessageType, read_message, write_message
from p2pfs.core.server import MessageServer
logger = logging.getLogger(__name__)
class DownloadManager:
def __init__(self, tracker_reader, tracker_writer, filename, server_address):
self._tracker_reader = tracker_reader
self._tracker_writer = tracker_writer
self._filename = filename
self._server_address = server_address
self._chunkinfo = {}
# peers and their read tasks
# peer_address -> [reader, writer, RTT]
self._peers = {}
self._read_tasks = {}
# indicating the tracker's connectivity
self._is_connected = True
async def _test_peer_rtt(self, address):
reader, writer = self._peers[address]
self._peers[address][2] = time.time()
await write_message(writer, {
'type': MessageType.PEER_PING_PONG,
'peer_address': address
})
await read_message(reader)
self._peers[address][2] = time.time() - self._peers[address][2]
async def _test_multi_peer_rtt(self, addresses):
""" Test multiple peer's rtt, must have registered in _peers"""
read_coros = set()
for address in addresses:
reader, writer, _ = self._peers[address]
# send out ping packet
await write_message(writer, {
'type': MessageType.PEER_PING_PONG,
'peer_address': address
})
# register read task
read_coros.add(read_message(reader))
# set current time
self._peers[address][2] = time.time()
# start reading from peers to get pong packets
for done in asyncio.as_completed({asyncio.ensure_future(read_coro) for read_coro in read_coros}):
message = await done
address = message['peer_address']
self._peers[address][2] = time.time() - self._peers[address][2]
async def _request_chunkinfo(self, filename):
await write_message(self._tracker_writer, {
'type': MessageType.REQUEST_FILE_LOCATION,
'filename': filename
})
message = await read_message(self._tracker_reader)
assert MessageType(message['type']) == MessageType.REPLY_FILE_LOCATION
fileinfo, chunkinfo = message['fileinfo'], message['chunkinfo']
logger.debug('{}: {} ==> {}'.format(filename, fileinfo, chunkinfo))
# cancel out self registration
if json.dumps(self._server_address) in chunkinfo:
del chunkinfo[json.dumps(self._server_address)]
return fileinfo, chunkinfo
async def update_chunkinfo(self):
if not self._is_connected:
return
async def download(self):
pass
async def clean(self):
# cancel current reading tasks
for task in self._read_tasks.keys():
task.cancel()
# close the connections
for _, (_, writer) in self._peers.items():
if not writer.is_closing():
writer.close()
await writer.wait_closed()
class Peer(MessageServer):
_CHUNK_SIZE = 512 * 1024
_HASH_FUNC = hashlib.sha256
@ -34,7 +114,7 @@ class Peer(MessageServer):
# tracker disconnects suddenly
try:
await self._tracker_writer.drain()
except ConnectionResetError:
except (ConnectionResetError, BrokenPipeError):
can_write = False
if not self._tracker_writer.is_closing():
self._tracker_writer.close()
@ -54,11 +134,11 @@ class Peer(MessageServer):
return False, 'Server connection refused!'
# send out register message
logger.info('Requesting to register')
await self._write_message(self._tracker_writer, {
await write_message(self._tracker_writer, {
'type': MessageType.REQUEST_REGISTER,
'address': self._server_address
})
message = await self._read_message(self._tracker_reader)
message = await read_message(self._tracker_reader)
assert MessageType(message['type']) == MessageType.REPLY_REGISTER
logger.info('Successfully registered.')
return True, 'Connected!'
@ -102,7 +182,7 @@ class Peer(MessageServer):
self._pending_publish.add(remote_name)
# send out the request packet
await self._write_message(self._tracker_writer, {
await write_message(self._tracker_writer, {
'type': MessageType.REQUEST_PUBLISH,
'filename': remote_name,
'fileinfo': {
@ -111,7 +191,7 @@ class Peer(MessageServer):
},
})
message = await self._read_message(self._tracker_reader)
message = await read_message(self._tracker_reader)
assert MessageType(message['type']) == MessageType.REPLY_PUBLISH
is_success, message = message['result'], message['message']
@ -127,10 +207,10 @@ class Peer(MessageServer):
async def list_file(self):
if not await self.is_connected():
return None, 'Not connected, try \'connect <tracker_ip> <tracker_port>\''
await self._write_message(self._tracker_writer, {
await write_message(self._tracker_writer, {
'type': MessageType.REQUEST_FILE_LIST,
})
message = await self._read_message(self._tracker_reader)
message = await read_message(self._tracker_reader)
assert MessageType(message['type']) == MessageType.REPLY_FILE_LIST
return message['file_list'], 'Success'
@ -274,7 +354,7 @@ class Peer(MessageServer):
'filename': file,
'chunknum': number
})
except ConnectionResetError:
except (ConnectionResetError, RuntimeError, BrokenPipeError):
# stop querying tracker
assert not await self.is_connected()
pass