# Source code for poco.utils.net.transport.tcp

# coding=utf-8

import errno
import select
import socket
import time
import uuid

from poco.sdk.std.transport import Transport
from poco.sdk.std.protocol import SimpleProtocolFilter
from poco.utils import six

if six.PY3:
    from queue import Queue, Empty
else:
    from Queue import Queue, Empty


class ConnectionReset(Exception):
    """Raised by ``Connection.recv`` when the connection can no longer be read."""


class Connection(object):
    """One TCP peer: a connected socket together with its protocol filter."""

    def __init__(self, cid, sock, endpoint, RX_SIZE=65536):
        """Wrap an already connected socket.

        Args:
            cid: identifier of this connection.
            sock: connected socket object used for all I/O.
            endpoint: address associated with ``sock``; stored unchanged.
            RX_SIZE: maximum number of bytes requested per ``recv`` call.
        """
        super(Connection, self).__init__()
        self.cid = cid
        self.sock = sock
        self.endpoint = endpoint
        # pack() turns outgoing packets into bytes, input() turns received
        # bytes back into packets.
        self.p = SimpleProtocolFilter()
        self.RX_SIZE = RX_SIZE

    def send(self, packet):
        """Pack ``packet`` with the protocol filter and send all the bytes."""
        self.sock.sendall(self.p.pack(packet))

    def recv(self):
        """Read once from the socket and yield every packet decoded from it.

        This is a generator, so nothing happens (and nothing is raised)
        until it is iterated.

        Raises:
            ConnectionReset: if the read returned no data or failed with a
                socket error. The socket is closed before raising.
        """
        try:
            rxdata = self.sock.recv(self.RX_SIZE)
        except socket.error:
            # Any receive failure is handled like an empty read: the
            # connection is considered lost.
            rxdata = b''
        if not rxdata:
            # Always release the descriptor, including on ECONNRESET.
            self.close()
            raise ConnectionReset
        for packet in self.p.input(rxdata):
            yield packet

    def close(self):
        """Close the socket, ignoring socket errors (best effort)."""
        try:
            self.sock.close()
        except socket.error:
            pass

    def get_socket_object(self):
        """Return the underlying socket object."""
        return self.sock


class TcpSocket(Transport):
    """TCP transport that can listen (``bind``) and/or dial out (``connect``).

    Packets received on any connection are queued by ``update`` and handed
    out one at a time as ``(cid, packet)`` tuples.
    """

    def __init__(self, RX_SIZE=65536):
        """
        Args:
            RX_SIZE: receive size passed to every ``Connection`` created.
        """
        super(TcpSocket, self).__init__()

        # listening socket, set by bind(); None while not bound
        self.s = None
        self.connections = {}  # sock -> Connection
        self.connections_endpoints = {}  # endpoint -> Connection
        # received (cid, packet) tuples waiting to be returned by recv()
        self.rq = Queue()
        self.RX_SIZE = RX_SIZE

    def connect(self, endpoint):
        """Open an outbound connection to ``endpoint`` and register it.

        Raises:
            RuntimeError: if ``endpoint`` is already registered.
        """
        if endpoint in self.connections_endpoints:
            raise RuntimeError("Already connected to {}".format(endpoint))
        c = socket.socket()
        try:
            c.connect(endpoint)
        except Exception:
            # Do not leak the descriptor when the connection attempt fails.
            c.close()
            raise
        cid = six.text_type(uuid.uuid4())
        conn = Connection(cid, c, endpoint, self.RX_SIZE)
        self.connections[c] = conn
        self.connections_endpoints[endpoint] = conn

    def disconnect(self, endpoint=None):
        """Close the connection for ``endpoint``, or all of them if ``None``.

        An unknown ``endpoint`` is ignored. The listening socket is left
        untouched.
        """
        if endpoint is not None:
            conn = self.connections_endpoints.pop(endpoint, None)
            if conn is not None:
                self.connections.pop(conn.get_socket_object(), None)
                conn.close()
            return

        for conn in list(self.connections_endpoints.values()):
            conn.close()
        # Also close anything only reachable through the socket map;
        # closing a connection twice is harmless.
        for conn in list(self.connections.values()):
            conn.close()
        self.connections_endpoints = {}
        self.connections = {}

    def bind(self, endpoint):
        """Start listening on ``endpoint``, an ``(ip, port)`` pair.

        An ip of ``''``, ``'*'`` or ``'0'`` is replaced by ``'0.0.0.0'``.

        Raises:
            RuntimeError: if a listening socket already exists.
        """
        if self.s is not None:
            raise RuntimeError("Already bound at {}".format(self.s.getsockname()))
        ip, port = endpoint
        if ip in ('', '*', '0'):
            ip = '0.0.0.0'
        listener = socket.socket()
        try:
            listener.bind((ip, port))
            listener.listen(10)
        except Exception:
            # Leave self.s unset so that bind() can be retried later.
            listener.close()
            raise
        self.s = listener
        print('server listens on ("{}", {}) transport socket'.format(ip, port))

    def update(self, timeout=0.002):
        """Poll all sockets once, queue received packets, return the oldest.

        Args:
            timeout: seconds to wait for socket activity.

        Returns:
            A ``(cid, packet)`` tuple, or ``(None, None)`` if nothing is queued.
        """
        rlist = []
        if self.s is not None:
            rlist.append(self.s)
        rlist.extend(self.connections.keys())

        if not rlist:
            # select() with three empty lists is platform-dependent and is
            # rejected on Windows; sleep instead to keep the same pacing.
            if timeout:
                time.sleep(timeout)
            return self.recv()

        readable, _, _ = select.select(rlist, [], [], timeout)
        for c in readable:
            if c is self.s:
                client_sock, endpoint = self.s.accept()
                print('accept from: {}'.format(endpoint))
                cid = six.text_type(uuid.uuid4())
                conn = Connection(cid, client_sock, endpoint, self.RX_SIZE)
                self.connections[client_sock] = conn
                self.connections_endpoints[endpoint] = conn
                continue

            conn = self.connections.get(c)
            if not conn:
                continue
            try:
                for packet in conn.recv():
                    self.rq.put((conn.cid, packet))
            except ConnectionReset:
                self._drop_connection(c, conn)
        return self.recv()

    def _drop_connection(self, sock, conn):
        """Remove ``conn`` from both lookup tables and close it."""
        self.connections.pop(sock, None)
        # Only remove the endpoint entry if it still refers to this
        # connection, so that the endpoint can be connected again.
        if self.connections_endpoints.get(conn.endpoint) is conn:
            self.connections_endpoints.pop(conn.endpoint, None)
        conn.close()

    def recv(self):
        """Return the oldest queued ``(cid, packet)`` without blocking.

        Returns ``(None, None)`` when the queue is empty.
        """
        try:
            return self.rq.get(False)
        except Empty:
            return None, None

    def send(self, cid, packet):
        """Send ``packet`` to connection ``cid``, or to all if ``cid`` is None.

        An unknown ``cid`` is silently ignored.
        """
        if cid is None:
            # broadcast
            for conn in list(self.connections.values()):
                conn.send(packet)
        else:
            conn = self.get_connection(cid)
            if conn:
                conn.send(packet)

    def get_connection(self, cid):
        """Return the connection whose ``cid`` matches, or ``None``."""
        for conn in self.connections.values():
            if conn.cid == cid:
                return conn
        return None

    def __str__(self):
        return 'Tcp connection(s) at {}'.format(self.connections_endpoints.keys())

    __repr__ = __str__