diff --git a/syncplay/network_utils.py b/syncplay/network_utils.py new file mode 100644 index 0000000..98eae3d --- /dev/null +++ b/syncplay/network_utils.py @@ -0,0 +1,49 @@ +#coding:utf8 + +from twisted.protocols.basic import LineReceiver + +class CommandProtocol(LineReceiver): + states = None + + def __init__(self): + self._state = self.initial_state + + def lineReceived(self, line): + line = line.strip() + if not line: + return + line = line.split(None, 1) + if len(line) != 2: + self.drop_with_error('Malformed line') + return + command, arg = line + + available_commands = self.states.get(self._state) + handler = available_commands.get(command) + if handler: + handler = getattr(self, handler, None) + if not handler: + self.drop_with_error('Unknown command: `%s`' % command) + return # TODO log it too + + handler(arg) + + def change_state(self, new_state): + if new_state not in self.states: + raise RuntimeError('Unknown state: %s' % new_state) + self._state = new_state + + def send_message(self, *args): + self.sendLine(' '.join( + (arg if isinstance(arg, basestring) else str(arg)) + for arg in args + if arg is not None + )) + + def drop(self): + self.transport.loseConnection() + + def drop_with_error(self, error): + self.send_message('error', error) + self.drop() + diff --git a/syncplay/server.py b/syncplay/server.py index 70675ec..d55b580 100644 --- a/syncplay/server.py +++ b/syncplay/server.py @@ -3,80 +3,37 @@ import time from twisted.internet.protocol import Factory -from twisted.protocols.basic import LineReceiver -class SyncProtocol(LineReceiver): +from .network_utils import CommandProtocol + +class SyncServerProtocol(CommandProtocol): def __init__(self, factory): - self._factory = factory + CommandProtocol.__init__(self) - self._state = 'init' - self._active = False - - def connectionMade(self): - self._active = True + self.factory = factory def connectionLost(self, reason): - self._active = False - self._factory.remove_watcher(self) + self.factory.remove_watcher(self) - def lineReceived(self, line): - line = line.strip() - if not line: - return - line = line.split(None, 1) - if len(line) != 2: - self._drop_with_error('Malformed line') - return - command, arg = line - - available_commands = self.states.get(self._state) - if not available_commands: - return # TODO log it - - handler = available_commands.get(command) - if handler: - handler = getattr(self, handler, None) - if not handler: - self._drop_with_error('Unknown command: `%s`' % command) - return # TODO log it too - - handler(arg) - - - def _get_ident(self): + def get_ident(self): return '|'.join(( self.transport.getPeer().host, str(id(self)), )) - def _send(self, *args): - self.sendLine(' '.join( - (arg if isinstance(arg, basestring) else str(arg)) - for arg in args - if arg is not None - )) + def handle_init_iam(self, arg): + self.factory.add_watcher(self, arg.strip()) + self.change_state('connected') - def _drop(self): - self._active = False - self.transport.loseConnection() - - def _drop_with_error(self, error): - self._send('error', error) - self._drop() - - def _handle_init_iam(self, arg): - self._factory.add_watcher(self, arg.strip()) - self._state = 'connected' - - def _handle_connected_state(self, arg): + def handle_connected_state(self, arg): arg = arg.split(None, 1) if len(arg) != 2: - self._drop_with_error('Malformed state attributes') + self.drop_with_error('Malformed state attributes') return state, position = arg if not state in ('paused', 'playing'): - self._drop_with_error('Unknown state') + self.drop_with_error('Unknown state') return paused = state == 'paused' @@ -84,43 +41,44 @@ class SyncProtocol(LineReceiver): try: position = int(position) except ValueError: - self._drop_with_error('Invalid position numeral') + self.drop_with_error('Invalid position numeral') position /= 100.0 - self._factory.update_state(self, paused, position) + self.factory.update_state(self, paused, position) - def _handle_connected_seek(self, arg): + def handle_connected_seek(self, arg): try: position = int(arg) except ValueError: - self._drop_with_error('Invalid position numeral') + self.drop_with_error('Invalid position numeral') position /= 100.0 - self._factory.seek(self, position) + self.factory.seek(self, position) def __hash__(self): - return hash(self._get_ident()) + return hash(self.get_ident()) def send_state(self, paused, position, who_last_changed): - self._send('state', ('paused' if paused else 'playing'), int(position*100), who_last_changed) + self.send_message('state', ('paused' if paused else 'playing'), int(position*100), who_last_changed) def send_seek(self, position, who_seeked): - self._send('seek', int(position*100), who_seeked) + self.send_message('seek', int(position*100), who_seeked) states = dict( init = dict( - iam = '_handle_init_iam', + iam = 'handle_init_iam', ), connected = dict( - state = '_handle_connected_state', - seek = '_handle_connected_seek', - #ping = '_handle_connected_ping', + state = 'handle_connected_state', + seek = 'handle_connected_seek', + #ping = 'handle_connected_ping', ), ) + initial_state = 'init' class WatcherInfo(object): @@ -151,13 +109,13 @@ class SyncFactory(Factory): self.update_time_limit = update_time_limit def buildProtocol(self, addr): - return SyncProtocol(self) + return SyncServerProtocol(self) def add_watcher(self, watcher_proto, name): watcher = WatcherInfo(watcher_proto, name) self.watchers[watcher_proto] = watcher - self._send_state_to(watcher) + self.send_state_to(watcher) # send info someone joined def remove_watcher(self, watcher_proto): @@ -186,23 +144,23 @@ class SyncFactory(Factory): else: pause_changed = False - position = self._find_position() + position = self.find_position() for receiver in self.watchers.itervalues(): if ( receiver == watcher or pause_changed or (curtime-receiver.last_update_sent) > self.update_time_limit ): - self._send_state_to(receiver, position, curtime) + self.send_state_to(receiver, position, curtime) def seek(self, watcher_proto, position): #TODO #for receiver in self.watchers.itervalues(): pass - def _send_state_to(self, watcher, position=None, curtime=None): + def send_state_to(self, watcher, position=None, curtime=None): if position is None: - position = self._find_position() + position = self.find_position() if curtime is None: curtime = time.time() if self.pause_change_by: @@ -211,7 +169,7 @@ class SyncFactory(Factory): watcher.watcher_proto.send_state(self.paused, position, None) watcher.last_update_sent = curtime - def _find_position(self): + def find_position(self): curtime = time.time() try: return min(