diff --git a/syncplay/client.py b/syncplay/client.py index aa4645e..834e52c 100644 --- a/syncplay/client.py +++ b/syncplay/client.py @@ -5,7 +5,10 @@ import time from twisted.internet import reactor from twisted.internet.protocol import ClientFactory -from .network_utils import CommandProtocol +from .network_utils import ( + arg_count, + CommandProtocol, +) from .utils import parse_state @@ -26,6 +29,7 @@ class SyncClientProtocol(CommandProtocol): self.manager.stop() CommandProtocol.handle_error(self, args) + @arg_count(3, 4) def handle_connected_state(self, args): args = parse_state(args) if not args: @@ -36,10 +40,8 @@ class SyncClientProtocol(CommandProtocol): self.manager.update_global_state(counter, paused, position, name) + @arg_count(1) def handle_connected_ping(self, args): - if not len(args) == 1: - self.drop_with_error('Invalid arguments') - return self.send_message('pong', args[0]) def send_state(self, counter, paused, position): @@ -162,8 +164,9 @@ class Manager(object): self.protocol = protocol self.schedule_send_status() self.send_filename() - if self.player is None: + if self.make_player: self.make_player(self) + self.make_player = None def schedule_ask_player(self, when=0.2): diff --git a/syncplay/network_utils.py b/syncplay/network_utils.py index 1c58632..014cd23 100644 --- a/syncplay/network_utils.py +++ b/syncplay/network_utils.py @@ -5,6 +5,8 @@ try: except ImportError: from StringIO import StringIO +from functools import wraps + from twisted.internet.defer import succeed from twisted.internet.protocol import ( ProcessProtocol, @@ -20,6 +22,17 @@ from .utils import ( split_args, ) +def arg_count(minimum, maximum=None): + def decorator(f): + @wraps(f) + def wrapper(self, args): + if ((len(args) != minimum) if maximum is None else not (minimum <= len(args) <= maximum)): + self.drop_with_error('Invalid arguments') + return + return f(self, args) + return wrapper + return decorator + class CommandProtocol(LineReceiver): states = None diff --git a/syncplay/server.py b/syncplay/server.py index c25be65..c12667b 100644 --- a/syncplay/server.py +++ b/syncplay/server.py @@ -6,7 +6,10 @@ import random from twisted.internet import reactor from twisted.internet.protocol import Factory -from .network_utils import CommandProtocol +from .network_utils import ( + arg_count, + CommandProtocol, +) from .utils import parse_state random.seed() @@ -32,6 +35,7 @@ class SyncServerProtocol(CommandProtocol): self.factory.add_watcher(self, args[0]) self.change_state('connected') + @arg_count(3) def handle_connected_state(self, args): args = parse_state(args) if not args: @@ -42,11 +46,8 @@ class SyncServerProtocol(CommandProtocol): self.factory.update_state(self, counter, paused, position) + @arg_count(1) def handle_connected_seek(self, args): - if not len(args) == 1: - self.drop_with_error('Invalid arguments') - return - try: position = int(args[0]) except ValueError: @@ -56,17 +57,13 @@ class SyncServerProtocol(CommandProtocol): self.factory.seek(self, position) + @arg_count(1) def handle_connected_pong(self, args): - if not len(args) == 1: - self.drop_with_error('Invalid arguments') - return self.factory.pong_received(self, args[0]) + @arg_count(1) def handle_connected_playing(self, args): - if not len(args) == 1: - self.drop_with_error('Invalid arguments') - return - #self.factory.pong_received(self, args[0]) + pass def __hash__(self): return hash('|'.join((