Use Queue() for synchronization.

This commit is contained in:
Yuxin Wang 2018-09-26 21:11:52 -04:00
parent 01dd79b974
commit bb06947048
2 changed files with 35 additions and 36 deletions

View file

@ -4,6 +4,7 @@ import socket
import logging
import os.path
import threading
from queue import Queue
logger = logging.getLogger(__name__)
@ -14,13 +15,18 @@ class Peer(MessageServer):
# (remote filename) <-> (local filename)
self._file_map = {}
# locks and results for publish method
self._publish_locks = {}
# lock and results for publish method
self._publish_lock = threading.Lock()
self._publish_results = {}
# locks and results for list_file method
self._list_file_lock = threading.Lock()
self._file_list = {}
# lock and results for list_file method
self._file_list = None
self._file_list_lock = threading.Lock()
self._file_list_result = Queue()
# lock and results for download
self._download_lock = threading.Lock()
self._download_results = {}
# socket connected to server
try:
@ -30,53 +36,41 @@ class Peer(MessageServer):
exit(1)
def publish(self, file):
# TODO: this method is not thread-safe, the following line only prevents sequential re-entrant
# but 2 threads can both pass the condition check and the lock-twice trick will possibly cause a deadlock
if file in self._publish_locks:
return False, 'Publish file {} already in progress.'.format(file)
path, filename = os.path.split(file)
# guard the check to prevent 2 threads passing the check simultaneously
with self._publish_lock:
if filename in self._publish_results:
return False, 'Publish file {} already in progress.'.format(file)
self._publish_results[filename] = Queue(maxsize=1)
if not os.path.exists(file):
return False, 'File {} doesn\'t exist'.format(file)
# send out the request packet
path, filename = os.path.split(file)
self._write_message(self._server_sock, {
'type': MessageType.REQUEST_PUBLISH,
'filename': filename,
'size': os.stat(file).st_size
})
# we need to lock twice
lock = threading.Lock()
self._publish_locks[filename] = lock
# first lock acquires the resource
lock.acquire()
# second lock blocks the method until reply is ready
lock.acquire()
# when we wake up it means the result is ready
is_success, message = self._publish_results[filename]
# queue will block until the result is ready
is_success, message = self._publish_results[filename].get()
if is_success:
self._file_map[filename] = file
logger.info('File {} published on server with name {}'.format(file, filename))
else:
logger.info('File {} failed to publish, {}'.format(file, message))
# remove the locks and results
del self._publish_locks[filename]
del self._publish_results[filename]
# just one release here, the other release is called within _process_message method when
# reply is ready
lock.release()
# remove result
with self._publish_lock:
del self._publish_results[filename]
return is_success, message
def list_file(self):
# TODO: lock-twice trick is not thread-safe
self._write_message(self._server_sock, {
'type': MessageType.REQUEST_FILE_LIST,
})
# same technique of publish method
self._list_file_lock.acquire()
self._list_file_lock.acquire()
self._list_file_lock.release()
with self._file_list_lock:
self._file_list = self._file_list_result.get()
return self._file_list
def download(self, file, destination, progress):
@ -120,11 +114,11 @@ class Peer(MessageServer):
assert client in self._peers
self._peers[client] = message['id']
elif message['type'] == MessageType.REPLY_PUBLISH:
self._publish_results[message['filename']] = (message['result'], message['message'])
self._publish_locks[message['filename']].release()
self._publish_results[message['filename']].put((message['result'], message['message']))
elif message['type'] == MessageType.REPLY_FILE_LIST:
self._file_list = message['file_list']
self._list_file_lock.release()
self._file_list_result.put(message['file_list'])
elif message['type'] == MessageType.REPLY_FILE_LOCATION:
self._download_results[message['filename']].put(message['chunkinfo'])
else:
logger.error('Undefined message with type {}, full message: {}'.format(message['type'], message))

View file

@ -52,7 +52,12 @@ class PeerTerminal(cmd.Cmd):
print(self._peer.list_file())
def do_download(self, arg):
# TODO: implement download function
pass
filename, destionation, *_ = arg.split(' ')
def progress(current, total):
print('{} / {}, {}%'.format(current, total, int(current * 100 / total)))
self._peer.download(filename, destionation, progress)
def do_exit(self, arg):
self._peer.exit()