Use read_message and write_message function instead of methods.
Also add DownloadManager class.
This commit is contained in:
parent
e4567c45a7
commit
c9ba7fbe77
1 changed files with 89 additions and 9 deletions
|
@ -6,10 +6,90 @@ import time
|
|||
import hashlib
|
||||
import asyncio
|
||||
import pybase64
|
||||
from p2pfs.core.server import MessageServer, MessageType
|
||||
from p2pfs.core.message import MessageType, read_message, write_message
|
||||
from p2pfs.core.server import MessageServer
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DownloadManager:
|
||||
def __init__(self, tracker_reader, tracker_writer, filename, server_address):
|
||||
self._tracker_reader = tracker_reader
|
||||
self._tracker_writer = tracker_writer
|
||||
self._filename = filename
|
||||
self._server_address = server_address
|
||||
|
||||
self._chunkinfo = {}
|
||||
# peers and their read tasks
|
||||
# peer_address -> [reader, writer, RTT]
|
||||
self._peers = {}
|
||||
self._read_tasks = {}
|
||||
|
||||
# indicating the tracker's connectivity
|
||||
self._is_connected = True
|
||||
|
||||
async def _test_peer_rtt(self, address):
|
||||
reader, writer = self._peers[address]
|
||||
self._peers[address][2] = time.time()
|
||||
await write_message(writer, {
|
||||
'type': MessageType.PEER_PING_PONG,
|
||||
'peer_address': address
|
||||
})
|
||||
await read_message(reader)
|
||||
self._peers[address][2] = time.time() - self._peers[address][2]
|
||||
|
||||
async def _test_multi_peer_rtt(self, addresses):
|
||||
""" Test multiple peer's rtt, must have registered in _peers"""
|
||||
read_coros = set()
|
||||
for address in addresses:
|
||||
reader, writer, _ = self._peers[address]
|
||||
# send out ping packet
|
||||
await write_message(writer, {
|
||||
'type': MessageType.PEER_PING_PONG,
|
||||
'peer_address': address
|
||||
})
|
||||
# register read task
|
||||
read_coros.add(read_message(reader))
|
||||
# set current time
|
||||
self._peers[address][2] = time.time()
|
||||
# start reading from peers to get pong packets
|
||||
for done in asyncio.as_completed({asyncio.ensure_future(read_coro) for read_coro in read_coros}):
|
||||
message = await done
|
||||
address = message['peer_address']
|
||||
self._peers[address][2] = time.time() - self._peers[address][2]
|
||||
|
||||
async def _request_chunkinfo(self, filename):
|
||||
await write_message(self._tracker_writer, {
|
||||
'type': MessageType.REQUEST_FILE_LOCATION,
|
||||
'filename': filename
|
||||
})
|
||||
|
||||
message = await read_message(self._tracker_reader)
|
||||
assert MessageType(message['type']) == MessageType.REPLY_FILE_LOCATION
|
||||
fileinfo, chunkinfo = message['fileinfo'], message['chunkinfo']
|
||||
logger.debug('{}: {} ==> {}'.format(filename, fileinfo, chunkinfo))
|
||||
# cancel out self registration
|
||||
if json.dumps(self._server_address) in chunkinfo:
|
||||
del chunkinfo[json.dumps(self._server_address)]
|
||||
return fileinfo, chunkinfo
|
||||
|
||||
async def update_chunkinfo(self):
|
||||
if not self._is_connected:
|
||||
return
|
||||
|
||||
async def download(self):
|
||||
pass
|
||||
|
||||
async def clean(self):
|
||||
# cancel current reading tasks
|
||||
for task in self._read_tasks.keys():
|
||||
task.cancel()
|
||||
# close the connections
|
||||
for _, (_, writer) in self._peers.items():
|
||||
if not writer.is_closing():
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
|
||||
|
||||
class Peer(MessageServer):
|
||||
_CHUNK_SIZE = 512 * 1024
|
||||
_HASH_FUNC = hashlib.sha256
|
||||
|
@ -34,7 +114,7 @@ class Peer(MessageServer):
|
|||
# tracker disconnects suddenly
|
||||
try:
|
||||
await self._tracker_writer.drain()
|
||||
except ConnectionResetError:
|
||||
except (ConnectionResetError, BrokenPipeError):
|
||||
can_write = False
|
||||
if not self._tracker_writer.is_closing():
|
||||
self._tracker_writer.close()
|
||||
|
@ -54,11 +134,11 @@ class Peer(MessageServer):
|
|||
return False, 'Server connection refused!'
|
||||
# send out register message
|
||||
logger.info('Requesting to register')
|
||||
await self._write_message(self._tracker_writer, {
|
||||
await write_message(self._tracker_writer, {
|
||||
'type': MessageType.REQUEST_REGISTER,
|
||||
'address': self._server_address
|
||||
})
|
||||
message = await self._read_message(self._tracker_reader)
|
||||
message = await read_message(self._tracker_reader)
|
||||
assert MessageType(message['type']) == MessageType.REPLY_REGISTER
|
||||
logger.info('Successfully registered.')
|
||||
return True, 'Connected!'
|
||||
|
@ -102,7 +182,7 @@ class Peer(MessageServer):
|
|||
self._pending_publish.add(remote_name)
|
||||
|
||||
# send out the request packet
|
||||
await self._write_message(self._tracker_writer, {
|
||||
await write_message(self._tracker_writer, {
|
||||
'type': MessageType.REQUEST_PUBLISH,
|
||||
'filename': remote_name,
|
||||
'fileinfo': {
|
||||
|
@ -111,7 +191,7 @@ class Peer(MessageServer):
|
|||
},
|
||||
})
|
||||
|
||||
message = await self._read_message(self._tracker_reader)
|
||||
message = await read_message(self._tracker_reader)
|
||||
assert MessageType(message['type']) == MessageType.REPLY_PUBLISH
|
||||
is_success, message = message['result'], message['message']
|
||||
|
||||
|
@ -127,10 +207,10 @@ class Peer(MessageServer):
|
|||
async def list_file(self):
|
||||
if not await self.is_connected():
|
||||
return None, 'Not connected, try \'connect <tracker_ip> <tracker_port>\''
|
||||
await self._write_message(self._tracker_writer, {
|
||||
await write_message(self._tracker_writer, {
|
||||
'type': MessageType.REQUEST_FILE_LIST,
|
||||
})
|
||||
message = await self._read_message(self._tracker_reader)
|
||||
message = await read_message(self._tracker_reader)
|
||||
assert MessageType(message['type']) == MessageType.REPLY_FILE_LIST
|
||||
return message['file_list'], 'Success'
|
||||
|
||||
|
@ -274,7 +354,7 @@ class Peer(MessageServer):
|
|||
'filename': file,
|
||||
'chunknum': number
|
||||
})
|
||||
except ConnectionResetError:
|
||||
except (ConnectionResetError, RuntimeError, BrokenPipeError):
|
||||
# stop querying tracker
|
||||
assert not await self.is_connected()
|
||||
pass
|
||||
|
|
Loading…
Reference in a new issue